File size: 4,715 Bytes
2216a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f96e2ca
 
 
 
 
2216a22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import pathlib
import tqdm
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy
from score import Score
import torch

# Suppress all warnings globally for cleaner CLI output
# (librosa/torch tend to emit noisy deprecation warnings).
import warnings
warnings.filterwarnings("ignore")


def get_arg():
    """Parse command-line options for voxsim scoring.

    Two modes exist: "predict_file" scores a single input/reference pair
    (requires --inp_path/--ref_path), while "predict_dir" scores paired,
    sorted directories in batches (requires --inp_dir/--ref_dir/--bs).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--bs", type=int, default=None, required=False)
    p.add_argument("--mode", type=str, required=True,
                   choices=["predict_file", "predict_dir"])
    p.add_argument("--ckpt_path", type=pathlib.Path,
                   default="wavlm_ecapa.model", required=False)
    p.add_argument("--inp_dir", type=pathlib.Path, default=None, required=False)
    p.add_argument("--ref_dir", type=pathlib.Path, default=None, required=False)
    p.add_argument("--inp_path", type=pathlib.Path, default=None, required=False)
    p.add_argument("--ref_path", type=pathlib.Path, default=None, required=False)
    p.add_argument("--out_path", type=pathlib.Path, required=True)
    p.add_argument("--num_workers", type=int, default=0, required=False)
    return p.parse_args()


def loadWav(filename, max_frames: int = 400):
    """Load audio and return (stacked fixed-length crops, full waveform).

    Parameters
    ----------
    filename : path-like or tuple
        Path to a wav file (decoded at 16 kHz), or an already-loaded
        ``(sample_rate, samples)`` pair, which is peak-normalized instead.
    max_frames : int
        Number of frames per crop; crop length in samples is
        ``max_frames * 160 + 240`` (160-sample hop, 240-sample window).

    Returns
    -------
    (torch.FloatTensor, torch.FloatTensor)
        A ``(10, max_audio)`` tensor of evenly spaced crops and a
        ``(1, n_samples)`` tensor holding the original waveform.
    """
    # Maximum audio length per crop.
    max_audio = max_frames * 160 + 240

    # Accept pre-loaded audio as a (sr, samples) tuple; otherwise decode
    # from disk. (Fix: use isinstance instead of `type(...) == tuple`.)
    if isinstance(filename, tuple):
        sr, audio = filename
        audio = librosa.util.normalize(audio)
    else:
        audio, sr = librosa.load(filename, sr=16000)
    audio_org = audio.copy()

    audiosize = audio.shape[0]

    # Wrap-pad so at least one full crop fits.
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    # Ten evenly spaced crop start positions spanning the signal.
    startframe = numpy.linspace(0, audiosize - max_audio, num=10)

    feats = [audio[int(asf):int(asf) + max_audio] for asf in startframe]
    feat = numpy.stack(feats, axis=0).astype(numpy.float32)

    return torch.FloatTensor(feat), torch.FloatTensor(
        numpy.stack([audio_org], axis=0).astype(numpy.float32))


class AudioDataset(Dataset):
    """Dataset of positionally paired input/reference wav files.

    Files are matched by sorted filename order, so both directories must
    contain the same number of ``*.wav`` files.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        self.inp_wavlist = list(inp_dir_path.glob("*.wav"))
        self.ref_wavlist = list(ref_dir_path.glob("*.wav"))
        # Pairing is positional after sorting, so the counts must agree.
        assert len(self.inp_wavlist) == len(self.ref_wavlist)
        self.inp_wavlist.sort()
        self.ref_wavlist.sort()
        # Native sample rate of the first input file (kept for reference).
        _, self.sr = librosa.load(self.inp_wavlist[0], sr=None)
        # BUG FIX: remember max_frames so __getitem__ honors it; previously
        # loadWav was always called with its default and this parameter was
        # silently ignored.
        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

    def __len__(self):
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        inp_wavs, inp_wav = loadWav(self.inp_wavlist[idx], self.max_frames)
        ref_wavs, ref_wav = loadWav(self.ref_wavlist[idx], self.max_frames)
        return inp_wavs, inp_wav, ref_wavs, ref_wav

def main():
    """CLI entry point for voxsim scoring.

    ``predict_file``: score one input/reference pair and write the single
    score to ``out_path``. ``predict_dir``: score paired directories in
    batches, appending one score per line to ``out_path``.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.mode == "predict_file":
        # Single-pair mode: the *_path arguments must be given, *_dir must not.
        assert args.inp_path is not None
        assert args.ref_path is not None
        assert args.inp_dir is None
        assert args.ref_dir is None
        assert args.inp_path.exists()
        assert args.inp_path.is_file()
        assert args.ref_path.exists()
        assert args.ref_path.is_file()
        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("Voxsim score: ", score[0])
        with open(args.out_path, "w") as fw:
            fw.write(str(score[0]))
    else:
        # Directory mode: the *_dir arguments plus --bs must be given.
        assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
        assert args.ref_dir is not None, "ref_dir is required when mode is predict_dir."
        assert args.bs is not None, "bs is required when mode is predict_dir."
        assert args.inp_path is None, "inp_path should be None"
        assert args.ref_path is None, "ref_path should be None"
        assert args.inp_dir.exists()
        assert args.ref_dir.exists()
        assert args.inp_dir.is_dir()
        assert args.ref_dir.is_dir()
        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        loader = DataLoader(
            dataset,
            batch_size=args.bs,
            shuffle=False,
            num_workers=args.num_workers)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        # Truncate any stale output file before appending batch results.
        with open(args.out_path, 'w'):
            pass
        for batch in tqdm.tqdm(loader):
            # BUG FIX: original read `scores = score.score(batch.to(device))`,
            # which raises NameError (`score` is never defined in this branch;
            # the object is `scorer`) and would also fail because the
            # DataLoader yields a tuple of four tensors, which has no `.to`.
            # Unpack and call the scorer the same way predict_file does.
            inp_wavs, inp_wav, ref_wavs, ref_wav = batch
            scores = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
            with open(args.out_path, 'a') as fw:
                for s in scores:
                    fw.write(str(s) + "\n")
        print("save to ", args.out_path)

if __name__ == "__main__":
    main()