import argparse
import pathlib
import warnings

import librosa
import numpy
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader

from score import Score

warnings.filterwarnings("ignore")
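
# Note: `Score` comes from the accompanying score.py and wraps the similarity model
# loaded from --ckpt_path (presumably a WavLM + ECAPA speaker-embedding checkpoint,
# judging by the default file name "wavlm_ecapa.model").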

def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--bs", required=False, default=None, type=int)
    parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
    parser.add_argument("--ckpt_path", required=False, default="wavlm_ecapa.model", type=pathlib.Path)
    parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--ref_dir", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--ref_path", required=False, default=None, type=pathlib.Path)
    parser.add_argument("--out_path", required=True, type=pathlib.Path)
    parser.add_argument("--num_workers", required=False, default=0, type=int)
    return parser.parse_args()
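
# Example invocations (script and wav/dir names below are placeholders, not part of the repo):
#   python predict.py --mode predict_file --inp_path converted.wav --ref_path target.wav --out_path score.txt
#   python predict.py --mode predict_dir --inp_dir converted/ --ref_dir target/ --bs 8 --out_path scores.txt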

def loadWav(filename, max_frames: int = 400):
    # Maximum audio length in samples: max_frames hops of 160 samples plus a 240-sample window
    max_audio = max_frames * 160 + 240

    # Read the wav file (or accept an already-loaded (sr, audio) tuple) and convert to torch tensors
    if isinstance(filename, tuple):
        sr, audio = filename
        audio = librosa.util.normalize(audio)
        print(numpy.linalg.norm(audio))  # debug output: L2 norm of the normalized waveform
    else:
        audio, sr = librosa.load(filename, sr=16000)
    audio_org = audio.copy()

    audiosize = audio.shape[0]
    # Pad short signals by wrapping so at least max_audio samples are available
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    # Take 10 evenly spaced crops of max_audio samples each
    startframe = numpy.linspace(0, audiosize - max_audio, num=10)
    feats = []
    for asf in startframe:
        feats.append(audio[int(asf):int(asf) + max_audio])
    feat = numpy.stack(feats, axis=0).astype(numpy.float32)

    # Return the stacked crops and the full-length (un-cropped) waveform
    return torch.FloatTensor(feat), torch.FloatTensor(numpy.stack([audio_org], axis=0).astype(numpy.float32))
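
# Shape sketch (file name is hypothetical):
#   crops, full = loadWav("example.wav")
#   crops.shape == (10, 64240)   # 10 crops of max_frames * 160 + 240 = 64240 samples each
#   full.shape  == (1, num_samples)   # the original, un-cropped waveform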

class AudioDataset(Dataset):
    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_wavlist = list(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = list(ref_dir_path.glob("*.wav"))
        assert len(self.inp_wavlist) == len(self.ref_wavlist)
        # Input/reference pairs are matched by sorted file name order
        self.inp_wavlist.sort()
        self.ref_wavlist.sort()
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx])
        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx])
        return inp_wavs, inp_wav, ref_wavs, ref_wav
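
# Note (assumption about batching): the DataLoader's default collate function stacks each of
# the four returned tensors along a new batch dimension, so with --bs > 1 the full-length
# waveforms must all have the same number of samples; the fixed-size crops batch cleanly
# regardless of file length.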

def main():
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.mode == "predict_file":
        assert args.inp_path is not None
        assert args.ref_path is not None
        assert args.inp_dir is None
        assert args.ref_dir is None
        assert args.inp_path.exists()
        assert args.inp_path.is_file()
        assert args.ref_path.exists()
        assert args.ref_path.is_file()
        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("Voxsim score: ", score[0])
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
    else:
        assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
        assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
        assert args.bs is not None, "bs is required when mode is predict_dir."
        assert args.inp_path is None, "inp_path should be None"
        assert args.ref_path is None, "ref_path should be None"
        assert args.inp_dir.exists()
        assert args.ref_dir.exists()
        assert args.inp_dir.is_dir()
        assert args.ref_dir.is_dir()
        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        loader = DataLoader(
            dataset,
            batch_size=args.bs,
            shuffle=False,
            num_workers=args.num_workers)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        # Truncate the output file before appending per-batch results
        with open(args.out_path, 'w'):
            pass
        for batch in tqdm.tqdm(loader):
            # Unpack the 4-tuple produced by AudioDataset.__getitem__ and move
            # each tensor to the target device before scoring
            inp_wavs, inp_wav, ref_wavs, ref_wav = batch
            scores = scorer.score(
                inp_wavs.to(device), inp_wav.to(device),
                ref_wavs.to(device), ref_wav.to(device))
            with open(args.out_path, 'a') as fw:
                for s in scores:
                    fw.write(str(s) + "\n")
        print("save to ", args.out_path)

if __name__ == "__main__":
    main()