File size: 4,269 Bytes
2216a22
 
 
 
08cc398
2216a22
08cc398
2216a22
 
 
 
 
 
 
08cc398
 
 
 
 
 
 
 
2216a22
 
 
 
 
08cc398
 
 
 
 
 
 
 
 
 
2216a22
08cc398
2216a22
 
 
 
 
 
08cc398
 
2216a22
 
 
 
 
 
08cc398
 
2216a22
 
08cc398
2216a22
 
 
 
 
08cc398
2216a22
08cc398
2216a22
 
 
 
 
 
 
 
 
 
08cc398
2216a22
 
 
08cc398
 
 
 
 
 
 
 
2216a22
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import pathlib
import tqdm
from torch.utils.data import Dataset, DataLoader
from score import loadWav, Score
import torch
import os

import warnings
# Suppress every Python warning for the whole process (the "ignore" action is
# installed with no category/module filter, so it matches all warnings).
# NOTE(review): presumably meant to hide noisy third-party warnings — confirm
# that hiding deprecation warnings as well is intended.
warnings.filterwarnings("ignore")


def get_arg():
    """Parse the command-line options of the scoring script.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the order shown below. ``--mode`` and ``--out_path`` are
    mandatory; every other option is optional and falls back to its default.

    Returns:
        argparse.Namespace with attributes ``mode``, ``ckpt_path``,
        ``inp_dir``, ``ref_dir``, ``inp_path``, ``ref_path``, ``out_path``
        and ``num_workers``. Path-like options are ``pathlib.Path`` objects
        (or ``None`` when omitted and no default is set).
    """
    option_table = (
        ("--mode", {
            "required": True,
            "choices": ["predict_file", "predict_dir"],
            "type": str,
            "help": "predict mode",
        }),
        ("--ckpt_path", {
            "required": False,
            "default": "voxsim_wavlm_ecapa.model",
            "type": pathlib.Path,
            "help": "path to the model checkpoint",
        }),
        ("--inp_dir", {
            "required": False,
            "default": None,
            "type": pathlib.Path,
            "help": "input directory when predict_dir mode",
        }),
        ("--ref_dir", {
            "required": False,
            "default": None,
            "type": pathlib.Path,
            "help": "reference directory when predict_dir mode",
        }),
        ("--inp_path", {
            "required": False,
            "default": None,
            "type": pathlib.Path,
            "help": "input file when predict_file mode",
        }),
        ("--ref_path", {
            "required": False,
            "default": None,
            "type": pathlib.Path,
            "help": "reference file when predict_file mode",
        }),
        ("--out_path", {
            "required": True,
            "type": pathlib.Path,
            "help": "output path",
        }),
        ("--num_workers", {
            "required": False,
            "default": 4,
            "type": int,
            "help": "number of workers for dataloader",
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


class AudioDataset(Dataset):
    """Dataset of same-named ``.wav`` pairs from an input and a reference directory.

    Each item corresponds to one ``.wav`` file name found in the input
    directory; the file with the identical name is read from the reference
    directory. Items are ordered by sorted file name.
    """

    def __init__(self, inp_dir_path: pathlib.Path, ref_dir_path: pathlib.Path, max_frames: int = 400):
        """Index the input directory and verify every file has a reference twin.

        Args:
            inp_dir_path: Directory whose ``.wav`` files are to be scored.
            ref_dir_path: Directory holding reference ``.wav`` files with the
                same names.
            max_frames: Used only to derive ``self.max_audio``.

        Raises:
            ValueError: If some ``.wav`` name in the input directory has no
                counterpart in the reference directory.
        """
        self.inp_dir_path = inp_dir_path
        self.ref_dir_path = ref_dir_path

        self.inp_wavlist = sorted(
            name for name in os.listdir(inp_dir_path) if name.endswith(".wav")
        )
        ref_names = {
            name for name in os.listdir(ref_dir_path) if name.endswith(".wav")
        }

        # Extra files in the reference directory are fine; missing ones are not.
        diff = sorted(set(self.inp_wavlist) - ref_names)
        if diff:
            raise ValueError(f"Files {diff} are in inp_dir but not in ref_dir.")

        # NOTE(review): max_audio is stored but never read inside this class;
        # the 160/240 constants presumably relate to frame hop/window sizes —
        # confirm before relying on it.
        self.max_audio = 240 + 160 * max_frames

    def __len__(self):
        """Return the number of ``.wav`` files found in the input directory."""
        return len(self.inp_wavlist)

    def __getitem__(self, idx):
        """Load the ``idx``-th input file and its same-named reference file.

        Returns:
            Tuple ``(inp_wavs, inp_wav, ref_wavs, ref_wav)`` made of the two
            values ``loadWav`` returns for the input file followed by the two
            it returns for the reference file.
        """
        wav_name = self.inp_wavlist[idx]
        inp_file = os.path.join(self.inp_dir_path, wav_name)
        ref_file = os.path.join(self.ref_dir_path, wav_name)
        inp_wavs, inp_wav = loadWav(inp_file)
        ref_wavs, ref_wav = loadWav(ref_file)
        return inp_wavs, inp_wav, ref_wavs, ref_wav

def _require_path(path, arg_name, mode, want_dir):
    """Check that a CLI path option was supplied and has the expected kind.

    Explicit raises are used instead of ``assert`` so the checks still run
    when Python is started with ``-O`` (which strips assert statements).

    Args:
        path: The parsed ``pathlib.Path`` value, or ``None`` if omitted.
        arg_name: Option name used in error messages (e.g. ``"inp_dir"``).
        mode: The selected ``--mode`` value, used in error messages.
        want_dir: ``True`` if ``path`` must be a directory, ``False`` if it
            must be a regular file.

    Raises:
        ValueError: If the option was omitted, or a file was expected but
            ``path`` is not a regular file.
        FileNotFoundError: If ``path`` does not exist.
        NotADirectoryError: If a directory was expected but ``path`` is not one.
    """
    if path is None:
        raise ValueError(f"{arg_name} is required when mode is {mode}.")
    if not path.exists():
        raise FileNotFoundError(f"{arg_name} does not exist: {path}")
    if want_dir:
        if not path.is_dir():
            raise NotADirectoryError(f"{arg_name} is not a directory: {path}")
    elif not path.is_file():
        raise ValueError(f"{arg_name} is not a file: {path}")


def main():
    """Run the scoring CLI in ``predict_file`` or ``predict_dir`` mode.

    ``predict_file``: scores one input/reference file pair, prints the score
    and writes it to ``--out_path``.

    ``predict_dir``: scores every ``.wav`` file in ``--inp_dir`` against the
    same-named file in ``--ref_dir``, writes one score per line to
    ``--out_path`` (in sorted file-name order) and prints the average.

    Raises:
        ValueError: If a required path option is missing, a file option does
            not point at a regular file, or ``--inp_dir`` contains no ``.wav``
            files (previously this surfaced as a ``ZeroDivisionError`` after
            the model had already been loaded).
        FileNotFoundError: If a given path does not exist.
        NotADirectoryError: If a directory option is not a directory.
    """
    args = get_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.mode == "predict_file":
        _require_path(args.inp_path, "inp_path", "predict_file", want_dir=False)
        _require_path(args.ref_path, "ref_path", "predict_file", want_dir=False)

        inp_wavs, inp_wav = loadWav(args.inp_path)
        ref_wavs, ref_wav = loadWav(args.ref_path)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
        print("VoxSIM score: ", score)
        with open(args.out_path, "w") as fw:
            fw.write(str(score))
    else:
        _require_path(args.inp_dir, "inp_dir", "predict_dir", want_dir=True)
        _require_path(args.ref_dir, "ref_dir", "predict_dir", want_dir=True)

        dataset = AudioDataset(args.inp_dir, args.ref_dir)
        # Fail early with a clear message: an empty dataset would otherwise
        # reach the average below and divide by zero.
        if len(dataset) == 0:
            raise ValueError(f"No .wav files found in {args.inp_dir}; nothing to score.")

        loader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,  # keep output lines aligned with sorted file names
            num_workers=args.num_workers)
        scorer = Score(ckpt_path=args.ckpt_path, device=device)
        scores = []
        with open(args.out_path, 'w') as fw:
            for batch in tqdm.tqdm(loader):
                inp_wavs, inp_wav, ref_wavs, ref_wav = batch
                score = scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav)
                scores.append(score)
                fw.write(str(score) + "\n")
        print("Average VoxSIM score: ", sum(scores) / len(scores))
        print("save to ", args.out_path)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()