File size: 2,473 Bytes
1dd0b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os, sys
import argparse
from SenseVoiceAx import SenseVoiceAx
from tokenizer import SentencepiecesTokenizer
from print_utils import rich_transcription_postprocess, rich_print_asr_res
from download_utils import download_model
import jiwer


def get_args():
    """Parse the command-line options of the evaluation script.

    Options:
        --dataset / -d   (required) path of the input dataset.
        --language / -l  (optional) recognition language; one of
                         auto, zh, en, yue, ja, ko. Defaults to "auto".

    Returns:
        argparse.Namespace with ``dataset`` and ``language`` attributes.
    """
    language_choices = ["auto", "zh", "en", "yue", "ja", "ko"]

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--dataset", "-d",
        type=str,
        required=True,
        help="Input dataset",
    )
    cli.add_argument(
        "--language", "-l",
        type=str,
        required=False,
        default="auto",
        choices=language_choices,
    )
    parsed = cli.parse_args()
    return parsed


def _load_ground_truth(gt_path):
    """Read ``<wav_name> <reference>`` pairs from a ground-truth text file.

    The file is read as UTF-8. Blank lines are skipped. Each remaining line is
    split on its first run of whitespace, so the reference part may itself
    contain spaces.

    Args:
        gt_path: Path of the ground-truth file.

    Returns:
        Tuple ``(wav_names, references)`` of equal-length lists of str.

    Raises:
        ValueError: if a non-blank line has no reference after the wav name.
    """
    wav_names = []
    references = []
    with open(gt_path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of failing
                # on tuple unpacking.
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{gt_path}:{lineno}: expected '<wav_name> <reference>', got {line!r}"
                )
            wav_names.append(parts[0])
            references.append(parts[1])
    return wav_names, references


def main():
    """Run SenseVoice (AX650 axmodel) over a dataset and report error rates.

    Expected dataset layout (relative to ``--dataset``):
        ground_truth.txt           lines of ``<wav_name> <reference>``
        aishell_S0764/<wav_name>.wav

    For every utterance the transcription (punctuation removed) is scored
    against its reference with ``jiwer.cer``. Each per-utterance line and a
    final total are printed and written to ``wer.txt`` in the current working
    directory. Note: although the output is labelled "WER", the metric is the
    character error rate; the labels are kept for output-format compatibility.

    Raises:
        ValueError: if ground_truth.txt is malformed or has no entries.
        FileNotFoundError: if the downloaded model file is missing.
    """
    args = get_args()

    dataset = args.dataset
    language = args.language
    use_itn = False  # punctuation prediction

    # Parse the ground truth first so a malformed dataset fails fast,
    # before the model is downloaded and loaded.
    gt_path = os.path.join(dataset, "ground_truth.txt")
    wav_names, references = _load_ground_truth(gt_path)
    if not references:
        raise ValueError(f"no utterances found in {gt_path}")

    model_path_root = download_model("SenseVoice")
    model_path = os.path.join(model_path_root, "sensevoice_ax650", "sensevoice.axmodel")
    bpemodel = os.path.join(model_path_root, "chn_jpn_yue_eng_ko_spectok.bpe.model")

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"model {model_path} not exist")

    print(f"dataset: {dataset}")
    print(f"language: {language}")
    print(f"use_itn: {use_itn}")
    print(f"model_path: {model_path}")

    tokenizer = SentencepiecesTokenizer(bpemodel=bpemodel)
    pipeline = SenseVoiceAx(model_path, language, use_itn, tokenizer=tokenizer)

    audio_dir = os.path.join(dataset, "aishell_S0764")

    # Iterate over dataset; `with` guarantees wer.txt is flushed and closed
    # even if inference raises part-way through.
    hyp = []
    with open("wer.txt", "w", encoding="utf-8") as wer_file:
        for wav_name, reference in zip(wav_names, references):
            wav_path = os.path.join(audio_dir, wav_name + ".wav")

            asr_res = pipeline.infer(wav_path, print_rtf=False)
            hypothesis = rich_print_asr_res(asr_res, will_print=False, remove_punc=True)
            hyp.append(hypothesis)

            wer = jiwer.cer(reference, hypothesis)

            line_content = f"{wav_name}  reference: {reference}  hypothesis: {hypothesis}  WER: {wer}"
            wer_file.write(line_content + "\n")
            print(line_content)

        # Corpus-level score over all utterances (not a mean of per-line scores).
        total_wer = jiwer.cer(references, hyp)
        print(f"Total WER: {total_wer}")
        wer_file.write(f"Total WER: {total_wer}\n")

if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()