# VoxSIM / predict.py
# Last commit: ce904ba ("new commit") by junseok; file size 4.76 kB.
import argparse
import pathlib
import tqdm
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy
from score import Score
import torch
import warnings
warnings.filterwarnings("ignore")
def get_arg():
    """Parse the command-line options of the prediction script.

    Options:
        --bs           DataLoader batch size (required by ``predict_dir``).
        --mode         ``predict_file`` (one input/reference pair) or
                       ``predict_dir`` (every pair of .wav files in two dirs).
        --ckpt_path    model checkpoint (default ``wavlm_ecapa.model``).
        --inp_dir, --ref_dir      directories used by ``predict_dir``.
        --inp_path, --ref_path    files used by ``predict_file``.
        --out_path     file the score(s) are written to (required).
        --num_workers  DataLoader worker processes (default 0).

    Returns:
        argparse.Namespace with the parsed values; path options are
        ``pathlib.Path`` objects.
    """
    as_path = pathlib.Path
    # (flag, add_argument keyword arguments), in help-output order.
    options = [
        ("--bs", {"default": None, "type": int}),
        ("--mode", {"required": True, "choices": ["predict_file", "predict_dir"], "type": str}),
        ("--ckpt_path", {"default": "wavlm_ecapa.model", "type": as_path}),
        ("--inp_dir", {"default": None, "type": as_path}),
        ("--ref_dir", {"default": None, "type": as_path}),
        ("--inp_path", {"default": None, "type": as_path}),
        ("--ref_path", {"default": None, "type": as_path}),
        ("--out_path", {"required": True, "type": as_path}),
        ("--num_workers", {"default": 0, "type": int}),
    ]
    cli = argparse.ArgumentParser()
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def loadWav(filename, max_frames: int = 400, num_eval: int = 10):
    """Load an utterance and cut it into evenly spaced fixed-length crops.

    Args:
        filename: either a path accepted by ``librosa.load`` (decoded and
            resampled to 16 kHz), or an in-memory ``(sample_rate, samples)``
            tuple. The samples are peak-normalised with
            ``librosa.util.normalize`` and resampled to 16 kHz when
            ``sample_rate`` differs.
        max_frames: crop length in frames; one crop spans
            ``max_frames * 160 + 240`` samples (64240 samples, about 4 s at
            16 kHz, for the default of 400).
        num_eval: number of crops, spread evenly from the start to the end
            of the utterance.

    Returns:
        ``(crops, full)``. ``crops`` is a float32 tensor of shape
        ``(num_eval, max_frames * 160 + 240)``. ``full`` is the whole,
        unpadded waveform as a float32 tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: if ``num_eval < 1`` or the audio contains no samples.
    """
    target_sr = 16000
    max_audio = max_frames * 160 + 240
    if num_eval < 1:
        raise ValueError(f"num_eval must be >= 1, got {num_eval}")

    if isinstance(filename, tuple):
        # NOTE(review): assumes mono, 1-D samples; confirm with callers.
        sr, audio = filename
        audio = librosa.util.normalize(audio)
        if sr != target_sr:
            # The path branch always yields 16 kHz audio. Bring in-memory
            # audio to the same rate so a crop covers the same duration.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    else:
        audio, _ = librosa.load(filename, sr=target_sr)

    num_samples = audio.shape[0]
    if num_samples == 0:
        raise ValueError("cannot build evaluation crops from empty audio")

    # numpy.pad returns a new array, so `full` keeps the unpadded waveform.
    full = audio
    if num_samples <= max_audio:
        # Tile short clips ('wrap') so every crop is filled with real audio.
        # The padded clip ends up one sample longer than a crop.
        audio = numpy.pad(audio, (0, max_audio - num_samples + 1), "wrap")
        num_samples = audio.shape[0]

    starts = numpy.linspace(0, num_samples - max_audio, num=num_eval)
    crops = numpy.stack(
        [audio[int(s):int(s) + max_audio] for s in starts], axis=0
    ).astype(numpy.float32)
    return (
        torch.FloatTensor(crops),
        torch.FloatTensor(full[numpy.newaxis, :].astype(numpy.float32)),
    )
class AudioDataset(Dataset):
    """Paired input/reference ``.wav`` files for directory-mode scoring.

    The ``*.wav`` files of each directory are sorted and paired by position:
    item ``i`` is the i-th input file together with the i-th reference file.
    File names are not compared, so both directories must sort into the same
    order for the pairing to be meaningful.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        """Collect and pair the ``.wav`` files of the two directories.

        Args:
            inp_dir_path: directory holding the input utterances.
            ref_dir_path: directory holding the reference utterances.
            max_frames: crop length forwarded to ``loadWav`` for every item.

        Raises:
            ValueError: if the directories hold different numbers of
                ``.wav`` files, or hold none at all.
        """
        self.inp_wavlist = sorted(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = sorted(ref_dir_path.glob("*.wav"))
        if len(self.inp_wavlist) != len(self.ref_wavlist):
            raise ValueError(
                f"{inp_dir_path} has {len(self.inp_wavlist)} .wav files but "
                f"{ref_dir_path} has {len(self.ref_wavlist)}; counts must match."
            )
        if not self.inp_wavlist:
            raise ValueError(f"no .wav files found in {inp_dir_path}")
        # Native sample rate of the first input file. Not read by this class
        # or by main(); items are always loaded at 16 kHz by loadWav.
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        """Return ``(inp_wavs, inp_wav, ref_wavs, ref_wav)`` for pair ``idx``.

        ``*_wavs`` are the fixed-length crops and ``*_wav`` the full
        waveform, as returned by ``loadWav``.
        """
        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx], max_frames=self.max_frames)
        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx], max_frames=self.max_frames)
        return inp_wavs, inp_wav, ref_wavs, ref_wav
def _require(condition, message):
    """Raise ValueError(message) unless condition holds (unlike assert, survives ``python -O``)."""
    if not condition:
        raise ValueError(message)


def _predict_file(args, device):
    """Score one input/reference pair; print the score and write it to ``args.out_path``."""
    _require(args.inp_path is not None, "inp_path is required when mode is predict_file.")
    _require(args.ref_path is not None, "ref_path is required when mode is predict_file.")
    _require(args.inp_dir is None, "inp_dir should be None")
    _require(args.ref_dir is None, "ref_dir should be None")
    _require(args.inp_path.is_file(), f"inp_path is not an existing file: {args.inp_path}")
    _require(args.ref_path.is_file(), f"ref_path is not an existing file: {args.ref_path}")

    inp_wavs, inp_wav = loadWav(args.inp_path)
    ref_wavs, ref_wav = loadWav(args.ref_path)
    scorer = Score(ckpt_path=args.ckpt_path, device=device)
    score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
    print("Voxsim score: ", score[0])
    with open(args.out_path, "w") as fw:
        fw.write(str(score[0]))


def _predict_dir(args, device):
    """Score every input/reference pair of two directories, one score per line of ``args.out_path``."""
    _require(args.inp_dir is not None, "inp_dir is required when mode is predict_dir.")
    _require(args.ref_dir is not None, "ref_dir is required when mode is predict_dir.")
    _require(args.bs is not None, "bs is required when mode is predict_dir.")
    _require(args.inp_path is None, "inp_path should be None")
    _require(args.ref_path is None, "ref_path should be None")
    _require(args.inp_dir.is_dir(), f"inp_dir is not an existing directory: {args.inp_dir}")
    _require(args.ref_dir.is_dir(), f"ref_dir is not an existing directory: {args.ref_dir}")

    dataset = AudioDataset(args.inp_dir, args.ref_dir)
    # Each item carries full-length waveforms whose lengths differ between
    # files, so the default collate (which stacks tensors) fails for bs > 1.
    # `list` just hands the samples through and, unlike a lambda, can be
    # pickled for worker processes.
    loader = DataLoader(
        dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=list,
    )
    scorer = Score(ckpt_path=args.ckpt_path, device=device)
    # Open once in "w" mode: truncates any previous output and is closed
    # (flushing what was written) even if scoring fails midway.
    with open(args.out_path, "w") as fw:
        for batch in tqdm.tqdm(loader):
            for inp_wavs, inp_wav, ref_wavs, ref_wav in batch:
                # Scored exactly like predict_file, one pair at a time, since
                # Score.score's batching semantics are not visible here.
                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
                fw.write(str(score[0]) + "\n")
    print("save to ", args.out_path)


def main():
    """Entry point: parse the CLI options and dispatch on ``--mode``."""
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        _predict_file(args, device)
    else:
        _predict_dir(args, device)


if __name__ == "__main__":
    main()