save every line wer results and update utmos evaluation
Browse files
src/f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -10,7 +10,7 @@ import multiprocessing as mp
|
|
10 |
from importlib.resources import files
|
11 |
|
12 |
import numpy as np
|
13 |
-
|
14 |
from f5_tts.eval.utils_eval import (
|
15 |
get_librispeech_test,
|
16 |
run_asr_wer,
|
@@ -56,12 +56,19 @@ def main():
|
|
56 |
# --------------------------- WER ---------------------------
|
57 |
if eval_task == "wer":
|
58 |
wers = []
|
|
|
59 |
with mp.Pool(processes=len(gpus)) as pool:
|
60 |
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
61 |
results = pool.map(run_asr_wer, args)
|
62 |
for wers_ in results:
|
63 |
wers.extend(wers_)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
wer = round(np.mean(wers) * 100, 3)
|
66 |
print(f"\nTotal {len(wers)} samples")
|
67 |
print(f"WER : {wer}%")
|
|
|
10 |
from importlib.resources import files
|
11 |
|
12 |
import numpy as np
|
13 |
+
import json
|
14 |
from f5_tts.eval.utils_eval import (
|
15 |
get_librispeech_test,
|
16 |
run_asr_wer,
|
|
|
56 |
# --------------------------- WER ---------------------------
|
57 |
if eval_task == "wer":
|
58 |
wers = []
|
59 |
+
wer_results = []
|
60 |
with mp.Pool(processes=len(gpus)) as pool:
|
61 |
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
62 |
results = pool.map(run_asr_wer, args)
|
63 |
for wers_ in results:
|
64 |
wers.extend(wers_)
|
65 |
|
66 |
+
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
|
67 |
+
for line in wers:
|
68 |
+
wer_results.append(line["wer"])
|
69 |
+
json_line = json.dumps(line, ensure_ascii=False)
|
70 |
+
f.write(json_line + "\n")
|
71 |
+
|
72 |
wer = round(np.mean(wers) * 100, 3)
|
73 |
print(f"\nTotal {len(wers)} samples")
|
74 |
print(f"WER : {wer}%")
|
src/f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -10,7 +10,7 @@ import multiprocessing as mp
|
|
10 |
from importlib.resources import files
|
11 |
|
12 |
import numpy as np
|
13 |
-
|
14 |
from f5_tts.eval.utils_eval import (
|
15 |
get_seed_tts_test,
|
16 |
run_asr_wer,
|
@@ -56,12 +56,19 @@ def main():
|
|
56 |
|
57 |
if eval_task == "wer":
|
58 |
wers = []
|
|
|
59 |
with mp.Pool(processes=len(gpus)) as pool:
|
60 |
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
61 |
results = pool.map(run_asr_wer, args)
|
62 |
for wers_ in results:
|
63 |
wers.extend(wers_)
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
wer = round(np.mean(wers) * 100, 3)
|
66 |
print(f"\nTotal {len(wers)} samples")
|
67 |
print(f"WER : {wer}%")
|
|
|
10 |
from importlib.resources import files
|
11 |
|
12 |
import numpy as np
|
13 |
+
import json
|
14 |
from f5_tts.eval.utils_eval import (
|
15 |
get_seed_tts_test,
|
16 |
run_asr_wer,
|
|
|
56 |
|
57 |
if eval_task == "wer":
|
58 |
wers = []
|
59 |
+
wer_results = []
|
60 |
with mp.Pool(processes=len(gpus)) as pool:
|
61 |
args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
|
62 |
results = pool.map(run_asr_wer, args)
|
63 |
for wers_ in results:
|
64 |
wers.extend(wers_)
|
65 |
|
66 |
+
with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
|
67 |
+
for line in wers:
|
68 |
+
wer_results.append(line["wer"])
|
69 |
+
json_line = json.dumps(line, ensure_ascii=False)
|
70 |
+
f.write(json_line + "\n")
|
71 |
+
|
72 |
wer = round(np.mean(wers) * 100, 3)
|
73 |
print(f"\nTotal {len(wers)} samples")
|
74 |
print(f"WER : {wer}%")
|
src/f5_tts/eval/eval_utmos.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import librosa
|
3 |
+
from pathlib import Path
|
4 |
+
import json
|
5 |
+
from tqdm import tqdm
|
6 |
+
import argparse
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
|
11 |
+
parser.add_argument(
|
12 |
+
"--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
|
13 |
+
)
|
14 |
+
parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
|
15 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
|
16 |
+
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
|
20 |
+
|
21 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
22 |
+
predictor = predictor.to(device)
|
23 |
+
|
24 |
+
lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
|
25 |
+
results = {}
|
26 |
+
utmos_result = 0
|
27 |
+
|
28 |
+
for line in tqdm(lines, desc="Processing"):
|
29 |
+
wave_name = line.stem
|
30 |
+
wave, sr = librosa.load(line, sr=None, mono=True)
|
31 |
+
wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
|
32 |
+
score = predictor(wave_tensor, sr)
|
33 |
+
results[str(wave_name)] = score.item()
|
34 |
+
utmos_result += score.item()
|
35 |
+
|
36 |
+
avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
|
37 |
+
print(f"UTMOS: {avg_score}")
|
38 |
+
|
39 |
+
output_path = Path(args.audio_dir) / "utmos_results.json"
|
40 |
+
with open(output_path, "w", encoding="utf-8") as f:
|
41 |
+
json.dump(results, f, ensure_ascii=False, indent=4)
|
42 |
+
|
43 |
+
print(f"Results have been saved to {output_path}")
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
main()
|
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
import torch.nn.functional as F
|
8 |
import torchaudio
|
9 |
from tqdm import tqdm
|
10 |
-
|
11 |
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
|
12 |
from f5_tts.model.modules import MelSpec
|
13 |
from f5_tts.model.utils import convert_char_to_pinyin
|
@@ -360,7 +360,14 @@ def run_asr_wer(args):
|
|
360 |
# dele = measures["deletions"] / len(ref_list)
|
361 |
# inse = measures["insertions"] / len(ref_list)
|
362 |
|
363 |
-
wers.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
return wers
|
366 |
|
|
|
7 |
import torch.nn.functional as F
|
8 |
import torchaudio
|
9 |
from tqdm import tqdm
|
10 |
+
from pathlib import Path
|
11 |
from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
|
12 |
from f5_tts.model.modules import MelSpec
|
13 |
from f5_tts.model.utils import convert_char_to_pinyin
|
|
|
360 |
# dele = measures["deletions"] / len(ref_list)
|
361 |
# inse = measures["insertions"] / len(ref_list)
|
362 |
|
363 |
+
wers.append(
|
364 |
+
{
|
365 |
+
"wav": Path(gen_wav).stem, # wav name
|
366 |
+
"truth": truth, # raw_truth
|
367 |
+
"hypo": hypo, # raw_hypo
|
368 |
+
"wer": wer, # wer score
|
369 |
+
}
|
370 |
+
)
|
371 |
|
372 |
return wers
|
373 |
|