zkniu committed on
Commit
4dd981f
·
1 Parent(s): 9894489

support command line set args

Browse files
src/f5_tts/eval/README.md CHANGED
@@ -42,8 +42,8 @@ Then update in the following scripts with the paths you put evaluation model ckp
42
  Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
43
  ```bash
44
  # Evaluation for Seed-TTS test set
45
- python src/f5_tts/eval/eval_seedtts_testset.py
46
 
47
  # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
48
- python src/f5_tts/eval/eval_librispeech_test_clean.py
49
  ```
 
42
  Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
43
  ```bash
44
  # Evaluation for Seed-TTS test set
45
+ python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
46
 
47
  # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
48
+ python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <LIBRISPEECH_TEST_CLEAN_PATH>
49
  ```
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -34,8 +34,6 @@ win_length = 1024
34
  n_fft = 1024
35
  target_rms = 0.1
36
 
37
-
38
- tokenizer = "pinyin"
39
  rel_path = str(files("f5_tts").joinpath("../../"))
40
 
41
 
@@ -49,6 +47,7 @@ def main():
49
  parser.add_argument("-n", "--expname", required=True)
50
  parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
51
  parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
 
52
 
53
  parser.add_argument("-nfe", "--nfestep", default=32, type=int)
54
  parser.add_argument("-o", "--odemethod", default="euler")
@@ -64,6 +63,7 @@ def main():
64
  ckpt_step = args.ckptstep
65
  ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
66
  mel_spec_type = args.mel_spec_type
 
67
 
68
  nfe_step = args.nfestep
69
  ode_method = args.odemethod
 
34
  n_fft = 1024
35
  target_rms = 0.1
36
 
 
 
37
  rel_path = str(files("f5_tts").joinpath("../../"))
38
 
39
 
 
47
  parser.add_argument("-n", "--expname", required=True)
48
  parser.add_argument("-c", "--ckptstep", default=1200000, type=int)
49
  parser.add_argument("-m", "--mel_spec_type", default="vocos", type=str, choices=["bigvgan", "vocos"])
50
+ parser.add_argument("-to", "--tokenizer", default="pinyin", type=str, choices=["pinyin", "char"])
51
 
52
  parser.add_argument("-nfe", "--nfestep", default=32, type=int)
53
  parser.add_argument("-o", "--odemethod", default="euler")
 
63
  ckpt_step = args.ckptstep
64
  ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
65
  mel_spec_type = args.mel_spec_type
66
+ tokenizer = args.tokenizer
67
 
68
  nfe_step = args.nfestep
69
  ode_method = args.odemethod
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import sys
4
  import os
 
5
 
6
  sys.path.append(os.getcwd())
7
 
@@ -19,55 +20,65 @@ from f5_tts.eval.utils_eval import (
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
21
 
22
- eval_task = "wer" # sim | wer
23
- lang = "en"
24
- metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
25
- librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
26
- gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
27
-
28
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
29
- test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
30
-
31
- ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
32
- ## leading to a low similarity for the ground truth in some cases.
33
- # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
34
-
35
- local = False
36
- if local: # use local custom checkpoint dir
37
- asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
38
- else:
39
- asr_ckpt_dir = "" # auto download to cache dir
40
-
41
- wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
42
-
43
-
44
- # --------------------------- WER ---------------------------
45
-
46
- if eval_task == "wer":
47
- wers = []
48
-
49
- with mp.Pool(processes=len(gpus)) as pool:
50
- args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
51
- results = pool.map(run_asr_wer, args)
52
- for wers_ in results:
53
- wers.extend(wers_)
54
-
55
- wer = round(np.mean(wers) * 100, 3)
56
- print(f"\nTotal {len(wers)} samples")
57
- print(f"WER : {wer}%")
58
-
59
-
60
- # --------------------------- SIM ---------------------------
61
-
62
- if eval_task == "sim":
63
- sim_list = []
64
-
65
- with mp.Pool(processes=len(gpus)) as pool:
66
- args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
67
- results = pool.map(run_sim, args)
68
- for sim_ in results:
69
- sim_list.extend(sim_)
70
-
71
- sim = round(sum(sim_list) / len(sim_list), 3)
72
- print(f"\nTotal {len(sim_list)} samples")
73
- print(f"SIM : {sim}")
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import sys
4
  import os
5
+ import argparse
6
 
7
  sys.path.append(os.getcwd())
8
 
 
20
  rel_path = str(files("f5_tts").joinpath("../../"))
21
 
22
 
23
+ def get_args():
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
26
+ parser.add_argument("-l", "--lang", type=str, default="en")
27
+ parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
28
+ parser.add_argument("-p", "--librispeech_test_clean_path", type=str, required=True)
29
+ parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
30
+ parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
31
+ return parser.parse_args()
32
+
33
+
34
+ def main():
35
+ args = get_args()
36
+ eval_task = args.eval_task
37
+ lang = args.lang
38
+ librispeech_test_clean_path = args.librispeech_test_clean_path # test-clean path
39
+ gen_wav_dir = args.gen_wav_dir
40
+ metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
41
+
42
+ gpus = list(range(args.gpu_nums))
43
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
44
+
45
+ ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
46
+ ## leading to a low similarity for the ground truth in some cases.
47
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
48
+
49
+ local = args.local
50
+ if local: # use local custom checkpoint dir
51
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
52
+ else:
53
+ asr_ckpt_dir = "" # auto download to cache dir
54
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
+
56
+ # --------------------------- WER ---------------------------
57
+ if eval_task == "wer":
58
+ wers = []
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_asr_wer, args)
62
+ for wers_ in results:
63
+ wers.extend(wers_)
64
+
65
+ wer = round(np.mean(wers) * 100, 3)
66
+ print(f"\nTotal {len(wers)} samples")
67
+ print(f"WER : {wer}%")
68
+
69
+ # --------------------------- SIM ---------------------------
70
+ if eval_task == "sim":
71
+ sim_list = []
72
+ with mp.Pool(processes=len(gpus)) as pool:
73
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
74
+ results = pool.map(run_sim, args)
75
+ for sim_ in results:
76
+ sim_list.extend(sim_)
77
+
78
+ sim = round(sum(sim_list) / len(sim_list), 3)
79
+ print(f"\nTotal {len(sim_list)} samples")
80
+ print(f"SIM : {sim}")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ main()
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import sys
4
  import os
 
5
 
6
  sys.path.append(os.getcwd())
7
 
@@ -19,57 +20,65 @@ from f5_tts.eval.utils_eval import (
19
  rel_path = str(files("f5_tts").joinpath("../../"))
20
 
21
 
22
- eval_task = "wer" # sim | wer
23
- lang = "zh" # zh | en
24
- metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
25
- # gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs" # ground truth wavs
26
- gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
27
-
28
-
29
- # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
30
- # zh 1.254 seems a result of 4 workers wer_seed_tts
31
- gpus = [0, 1, 2, 3, 4, 5, 6, 7]
32
- test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
33
-
34
- local = False
35
- if local: # use local custom checkpoint dir
36
- if lang == "zh":
37
- asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
38
- elif lang == "en":
39
- asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
40
- else:
41
- asr_ckpt_dir = "" # auto download to cache dir
42
-
43
- wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
44
-
45
-
46
- # --------------------------- WER ---------------------------
47
-
48
- if eval_task == "wer":
49
- wers = []
50
-
51
- with mp.Pool(processes=len(gpus)) as pool:
52
- args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
53
- results = pool.map(run_asr_wer, args)
54
- for wers_ in results:
55
- wers.extend(wers_)
56
-
57
- wer = round(np.mean(wers) * 100, 3)
58
- print(f"\nTotal {len(wers)} samples")
59
- print(f"WER : {wer}%")
60
-
61
-
62
- # --------------------------- SIM ---------------------------
63
-
64
- if eval_task == "sim":
65
- sim_list = []
66
-
67
- with mp.Pool(processes=len(gpus)) as pool:
68
- args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
69
- results = pool.map(run_sim, args)
70
- for sim_ in results:
71
- sim_list.extend(sim_)
72
-
73
- sim = round(sum(sim_list) / len(sim_list), 3)
74
- print(f"\nTotal {len(sim_list)} samples")
75
- print(f"SIM : {sim}")
 
 
 
 
 
 
 
 
 
2
 
3
  import sys
4
  import os
5
+ import argparse
6
 
7
  sys.path.append(os.getcwd())
8
 
 
20
  rel_path = str(files("f5_tts").joinpath("../../"))
21
 
22
 
23
+ def get_args():
24
+ parser = argparse.ArgumentParser()
25
+ parser.add_argument("-e", "--eval_task", type=str, default="wer", choices=["sim", "wer"])
26
+ parser.add_argument("-l", "--lang", type=str, default="en", choices=["zh", "en"])
27
+ parser.add_argument("-g", "--gen_wav_dir", type=str, required=True)
28
+ parser.add_argument("-n", "--gpu_nums", type=int, default=8, help="Number of GPUs to use")
29
+ parser.add_argument("--local", action="store_true", help="Use local custom checkpoint directory")
30
+ return parser.parse_args()
31
+
32
+
33
+ def main():
34
+ args = get_args()
35
+ eval_task = args.eval_task
36
+ lang = args.lang
37
+ gen_wav_dir = args.gen_wav_dir
38
+ metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
39
+
40
+ # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
41
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
42
+ gpus = list(range(args.gpu_nums))
43
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
44
+
45
+ local = args.local
46
+ if local: # use local custom checkpoint dir
47
+ if lang == "zh":
48
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
49
+ elif lang == "en":
50
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
51
+ else:
52
+ asr_ckpt_dir = "" # auto download to cache dir
53
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
54
+
55
+ # --------------------------- WER ---------------------------
56
+
57
+ if eval_task == "wer":
58
+ wers = []
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_asr_wer, args)
62
+ for wers_ in results:
63
+ wers.extend(wers_)
64
+
65
+ wer = round(np.mean(wers) * 100, 3)
66
+ print(f"\nTotal {len(wers)} samples")
67
+ print(f"WER : {wer}%")
68
+
69
+ # --------------------------- SIM ---------------------------
70
+ if eval_task == "sim":
71
+ sim_list = []
72
+ with mp.Pool(processes=len(gpus)) as pool:
73
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
74
+ results = pool.map(run_sim, args)
75
+ for sim_ in results:
76
+ sim_list.extend(sim_)
77
+
78
+ sim = round(sum(sim_list) / len(sim_list), 3)
79
+ print(f"\nTotal {len(sim_list)} samples")
80
+ print(f"SIM : {sim}")
81
+
82
+
83
+ if __name__ == "__main__":
84
+ main()