zkniu committed
Commit 22c95cd · 1 Parent(s): b666d33

save every line wer results and update utmos evaluation

src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -10,7 +10,7 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
+import json
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -56,12 +56,19 @@ def main():
     # --------------------------- WER ---------------------------
     if eval_task == "wer":
         wers = []
+        wer_results = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for wers_ in results:
                 wers.extend(wers_)
 
+        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
+            for line in wers:
+                wer_results.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
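Note: after this commit run_asr_wer returns one dict per utterance (see the utils_eval.py hunk further down), so the unchanged context lines above still average a list of dicts via np.mean(wers). A minimal follow-up sketch, assuming the summary is meant to use the numeric scores collected in wer_results; this is not part of the commit:

        # hypothetical follow-up: average the per-utterance scores, not the result dicts
        wer = round(np.mean(wer_results) * 100, 3)
        print(f"\nTotal {len(wer_results)} samples")
        print(f"WER : {wer}%")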
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -10,7 +10,7 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-
+import json
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -56,12 +56,19 @@ def main():
 
     if eval_task == "wer":
         wers = []
+        wer_results = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
             for wers_ in results:
                 wers.extend(wers_)
 
+        with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
+            for line in wers:
+                wer_results.append(line["wer"])
+                json_line = json.dumps(line, ensure_ascii=False)
+                f.write(json_line + "\n")
+
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
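The same np.mean(wers) caveat applies here. The saved {lang}_wer_results.jsonl holds one JSON object per utterance with wav, truth, hypo, and wer fields, which makes offline error analysis straightforward. A small sketch, assuming an English run whose results landed in results/en_wer_results.jsonl (an illustrative path, not a fixed output location):

import json

# load the per-utterance records written by the eval script
with open("results/en_wer_results.jsonl") as f:
    records = [json.loads(line) for line in f]

# rank by per-utterance WER and print the worst-recognized samples
records.sort(key=lambda r: r["wer"], reverse=True)
for r in records[:10]:
    print(f'{r["wav"]}  WER={r["wer"]:.3f}')
    print(f'  truth: {r["truth"]}')
    print(f'  hypo : {r["hypo"]}')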
src/f5_tts/eval/eval_utmos.py ADDED
@@ -0,0 +1,47 @@
+import torch
+import librosa
+from pathlib import Path
+import json
+from tqdm import tqdm
+import argparse
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
+    parser.add_argument(
+        "--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
+    )
+    parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
+    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
+
+    args = parser.parse_args()
+
+    device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
+
+    predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
+    predictor = predictor.to(device)
+
+    lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+    results = {}
+    utmos_result = 0
+
+    for line in tqdm(lines, desc="Processing"):
+        wave_name = line.stem
+        wave, sr = librosa.load(line, sr=None, mono=True)
+        wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
+        score = predictor(wave_tensor, sr)
+        results[str(wave_name)] = score.item()
+        utmos_result += score.item()
+
+    avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
+    print(f"UTMOS: {avg_score}")
+
+    output_path = Path(args.audio_dir) / "utmos_results.json"
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(results, f, ensure_ascii=False, indent=4)
+
+    print(f"Results have been saved to {output_path}")
+
+
+if __name__ == "__main__":
+    main()
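The new script can be run as, for example, python src/f5_tts/eval/eval_utmos.py --audio_dir <gen_wav_dir> --ext wav. One detail worth flagging: the line device = "cuda" if args.device and torch.cuda.is_available() else "cpu" only tests that --device is non-empty, so an explicit --device cpu still selects CUDA when it is available. A possible tightening, shown as a sketch rather than as part of the commit:

    # honor an explicit CPU request; fall back to CPU only when CUDA is unavailable
    device = "cuda" if args.device == "cuda" and torch.cuda.is_available() else "cpu"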
src/f5_tts/eval/utils_eval.py CHANGED
@@ -7,7 +7,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 from tqdm import tqdm
-
+from pathlib import Path
 from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
 from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import convert_char_to_pinyin
@@ -360,7 +360,14 @@ def run_asr_wer(args):
         # dele = measures["deletions"] / len(ref_list)
         # inse = measures["insertions"] / len(ref_list)
 
-        wers.append(wer)
+        wers.append(
+            {
+                "wav": Path(gen_wav).stem,  # wav name
+                "truth": truth,  # raw_truth
+                "hypo": hypo,  # raw_hypo
+                "wer": wer,  # wer score
+            }
+        )
 
     return wers
 