"""Command-line tool that scores input/reference audio pairs with ``Score``.

Two modes are supported:

* ``predict_file`` scores one input file against one reference file.
* ``predict_dir`` scores every ``*.wav`` in an input directory against the
  file at the same sorted position in a reference directory.
"""
import argparse
import pathlib
import warnings

import librosa
import numpy
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader

from score import Score

warnings.filterwarnings("ignore")

# Sample rate every file is resampled to when it is read from disk.
_SAMPLE_RATE = 16000
# Number of evenly spaced fixed-length crops taken from each utterance.
_NUM_CROPS = 10


def get_arg() -> argparse.Namespace:
    """Parse and return the command-line arguments.

    ``--mode`` and ``--out_path`` are always required. The remaining options
    are optional at parse time; ``main`` checks the combination that the
    selected mode needs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--bs", required=False, default=None, type=int)
    parser.add_argument(
        "--mode",
        required=True,
        choices=["predict_file", "predict_dir"],
        type=str,
    )
    parser.add_argument(
        "--ckpt_path",
        required=False,
        default="wavlm_ecapa.model",
        type=pathlib.Path,
    )
    parser.add_argument(
        "--inp_dir", required=False, default=None, type=pathlib.Path)
    parser.add_argument(
        "--ref_dir", required=False, default=None, type=pathlib.Path)
    parser.add_argument(
        "--inp_path", required=False, default=None, type=pathlib.Path)
    parser.add_argument(
        "--ref_path", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--out_path", required=True, type=pathlib.Path)
    parser.add_argument(
        "--num_workers", required=False, default=0, type=int)
    return parser.parse_args()


def loadWav(filename, max_frames: int = 400):
    """Load one utterance and cut it into fixed-length crops.

    Args:
        filename: Either a path readable by ``librosa.load`` (resampled to
            16 kHz) or a ``(sample_rate, samples)`` tuple. For a tuple the
            samples are peak-normalised with ``librosa.util.normalize`` and
            the sample rate is ignored (no resampling is performed).
        max_frames: Crop length in frames; each crop holds
            ``max_frames * 160 + 240`` samples.

    Returns:
        A pair of float32 tensors: the stacked crops (10 rows of
        ``max_frames * 160 + 240`` samples) and the untouched full-length
        audio with a leading axis of size 1.

    Raises:
        ValueError: If the audio contains no samples.
    """
    # Maximum audio length of a single crop, in samples.
    max_audio = max_frames * 160 + 240

    if isinstance(filename, tuple):
        _, audio = filename
        audio = librosa.util.normalize(audio)
    else:
        audio, _ = librosa.load(filename, sr=_SAMPLE_RATE)

    # Keep the unpadded signal; padding below must not leak into it.
    audio_org = audio.copy()

    audiosize = audio.shape[0]
    if audiosize == 0:
        # Wrap-padding an empty signal is impossible; fail with a clear error.
        raise ValueError(f"audio is empty: {filename!r}")

    if audiosize <= max_audio:
        # Repeat the signal from its start until it is one sample longer
        # than a crop.
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    # Evenly spaced crop offsets from the start to the last valid position.
    starts = numpy.linspace(0, audiosize - max_audio, num=_NUM_CROPS)
    crops = [audio[int(start):int(start) + max_audio] for start in starts]
    feat = numpy.stack(crops, axis=0).astype(numpy.float32)
    full = numpy.stack([audio_org], axis=0).astype(numpy.float32)

    return torch.FloatTensor(feat), torch.FloatTensor(full)


class AudioDataset(Dataset):
    """Pairs of input/reference wav files taken from two directories.

    Files are paired by position after sorting the ``*.wav`` paths of each
    directory, so both directories must hold the same number of files.
    """

    def __init__(self, inp_dir_path: pathlib.Path,
                 ref_dir_path: pathlib.Path, max_frames: int = 400):
        """Collect and pair the wav files.

        Raises:
            ValueError: If the directories hold a different number of wav
                files, or no wav files at all.
        """
        self.inp_wavlist = sorted(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = sorted(ref_dir_path.glob("*.wav"))
        if len(self.inp_wavlist) != len(self.ref_wavlist):
            raise ValueError(
                f"{inp_dir_path} has {len(self.inp_wavlist)} wav files but "
                f"{ref_dir_path} has {len(self.ref_wavlist)}")
        if not self.inp_wavlist:
            raise ValueError(f"no wav files found in {inp_dir_path}")

        # Native sample rate of the first input file.
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        """Return the number of input/reference pairs."""
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        """Return ``(inp_crops, inp_full, ref_crops, ref_full)`` for a pair."""
        inp_wavs, inp_wav = loadWav(
            self.inp_wavlist[idx], max_frames=self.max_frames)
        ref_wavs, ref_wav = loadWav(
            self.ref_wavlist[idx], max_frames=self.max_frames)
        return inp_wavs, inp_wav, ref_wavs, ref_wav


def _require(condition, message):
    """Raise ``ValueError(message)`` unless ``condition`` is true."""
    if not condition:
        raise ValueError(message)


def _keep_samples(samples):
    """Collate function that returns the samples of a batch unchanged."""
    return samples


def _check_file_mode_args(args):
    """Validate the argument combination required by ``predict_file``."""
    _require(args.inp_path is not None,
             "inp_path is required when mode is predict_file.")
    _require(args.ref_path is not None,
             "ref_path is required when mode is predict_file.")
    _require(args.inp_dir is None, "inp_dir should be None")
    _require(args.ref_dir is None, "ref_dir should be None")
    _require(args.inp_path.is_file(),
             f"inp_path is not an existing file: {args.inp_path}")
    _require(args.ref_path.is_file(),
             f"ref_path is not an existing file: {args.ref_path}")


def _check_dir_mode_args(args):
    """Validate the argument combination required by ``predict_dir``."""
    _require(args.inp_dir is not None,
             "inp_dir is required when mode is predict_dir.")
    _require(args.ref_dir is not None,
             "ref_dir is required when mode is predict_dir.")
    _require(args.bs is not None,
             "bs is required when mode is predict_dir.")
    _require(args.inp_path is None, "inp_path should be None")
    _require(args.ref_path is None, "ref_path should be None")
    _require(args.inp_dir.is_dir(),
             f"inp_dir is not an existing directory: {args.inp_dir}")
    _require(args.ref_dir.is_dir(),
             f"ref_dir is not an existing directory: {args.ref_dir}")


def main():
    """Score a single file pair or every pair of two directories.

    In ``predict_file`` mode the score is printed and written to
    ``--out_path``. In ``predict_dir`` mode one score per pair is written to
    ``--out_path``, one per line, in sorted file order.

    Raises:
        ValueError: If the arguments do not fit the selected mode.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.mode == "predict_file":
        _check_file_mode_args(args)

        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("Voxsim score: ", score[0])
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
        return

    _check_dir_mode_args(args)

    dataset = AudioDataset(args.inp_dir, args.ref_dir)
    # The full-length waveforms differ in length from file to file, so the
    # default collate function cannot stack them; keep each batch as a plain
    # list of samples instead.
    loader = DataLoader(
        dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=_keep_samples)
    scorer = Score(ckpt_path=args.ckpt_path, device=device)

    with open(args.out_path, "w") as fw:
        for batch in tqdm.tqdm(loader):
            # Score every pair with the same call used in predict_file mode.
            # NOTE(review): if Score.score accepts batched tensors, a padding
            # collate function could restore true batched inference.
            for inp_wavs, inp_wav, ref_wavs, ref_wav in batch:
                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
                fw.write(str(score[0]) + "\n")
            # Flush per batch so partial results survive an interruption.
            fw.flush()
    print("save to ", args.out_path)


if __name__ == "__main__":
    main()