reorganize infer_cli and stuff
Files changed:

- README.md +2 -2
- src/f5_tts/eval/README.md +9 -6
- src/f5_tts/eval/eval_librispeech_test_clean.py +19 -14
- src/f5_tts/eval/eval_seedtts_testset.py +18 -14
- src/f5_tts/eval/eval_utmos.py +24 -27
- src/f5_tts/eval/utils_eval.py +14 -13
- src/f5_tts/infer/README.md +3 -0
- src/f5_tts/infer/SHARED.md +19 -17
- src/f5_tts/infer/examples/basic/basic.toml +1 -1
- src/f5_tts/infer/examples/multi/story.toml +1 -0
- src/f5_tts/infer/infer_cli.py +161 -72
README.md
CHANGED
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
 ## Acknowledgements
 
 - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
-- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+- [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
 - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
 - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
 - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
-- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+- [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
 - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
 - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
 - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
src/f5_tts/eval/README.md
CHANGED
@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckp
 
 ### Objective Evaluation
 
-Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
 ```bash
-# Evaluation for Seed-TTS test set
-python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <…
+# Evaluation [WER] for Seed-TTS test [ZH] set
+python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
 
-# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <…
-…
+# Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+
+# Evaluation [UTMOS]. --ext: Audio extension
+python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
+```
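As the script changes below show, the WER task now writes per-sample results to `<GEN_WAV_DIR>/{lang}_wer_results.jsonl` and the UTMOS task writes `utmos_results.json` into the given `--audio_dir`, while the SIM task only prints a summary; all three commands can therefore be pointed at the same batch-inferenced directory back to back.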
src/f5_tts/eval/eval_librispeech_test_clean.py
CHANGED
@@ -1,8 +1,9 @@
 # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
 
-import sys
-import os
 import argparse
+import json
+import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-import json
 from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
@@ -54,36 +54,41 @@ def main():
     wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
 
     # --------------------------- WER ---------------------------
+
     if eval_task == "wer":
-        wers = []
         wer_results = []
+        wers = []
+
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
-            for …
-            …
+            for r in results:
+                wer_results.extend(r)
 
-        …
-        …
-        …
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
                 json_line = json.dumps(line, ensure_ascii=False)
                 f.write(json_line + "\n")
 
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
 
     # --------------------------- SIM ---------------------------
+
     if eval_task == "sim":
-        …
+        sims = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
-            for …
-            …
+            for r in results:
+                sims.extend(r)
 
-        sim = round(sum(…
-        print(f"\nTotal {len(…
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
         print(f"SIM : {sim}")
 
 
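Since the WER task now dumps one JSON line per sample (keys `wav`, `truth`, `hypo`, `wer`) instead of only printing an aggregate, the worst generations can be pulled up for listening without rerunning ASR. A minimal sketch, assuming the JSONL layout written by the updated script (the path below is illustrative):

```python
import json

# Illustrative path; the script writes f"{gen_wav_dir}/{lang}_wer_results.jsonl"
wer_result_path = "results/librispeech_test_clean/en_wer_results.jsonl"

with open(wer_result_path, "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# Sort utterances by per-sample WER, worst first
records.sort(key=lambda r: r["wer"], reverse=True)

for r in records[:10]:
    print(f"{r['wav']}\tWER={r['wer']:.3f}")
    print(f"  truth: {r['truth']}")
    print(f"  hypo : {r['hypo']}")
```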
src/f5_tts/eval/eval_seedtts_testset.py
CHANGED
@@ -1,8 +1,9 @@
 # Evaluate with Seed-TTS testset
 
-import sys
-import os
 import argparse
+import json
+import os
+import sys
 
 sys.path.append(os.getcwd())
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
 from importlib.resources import files
 
 import numpy as np
-import json
 from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
@@ -55,35 +55,39 @@ def main():
     # --------------------------- WER ---------------------------
 
     if eval_task == "wer":
-        wers = []
         wer_results = []
+        wers = []
+
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_asr_wer, args)
-            for …
-            …
+            for r in results:
+                wer_results.extend(r)
 
-        …
-        …
-        …
+        wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
+        with open(wer_result_path, "w") as f:
+            for line in wer_results:
+                wers.append(line["wer"])
                 json_line = json.dumps(line, ensure_ascii=False)
                 f.write(json_line + "\n")
 
         wer = round(np.mean(wers) * 100, 3)
         print(f"\nTotal {len(wers)} samples")
         print(f"WER : {wer}%")
+        print(f"Results have been saved to {wer_result_path}")
 
     # --------------------------- SIM ---------------------------
+
     if eval_task == "sim":
-        …
+        sims = []
         with mp.Pool(processes=len(gpus)) as pool:
             args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
             results = pool.map(run_sim, args)
-            for …
-            …
+            for r in results:
+                sims.extend(r)
 
-        sim = round(sum(…
-        print(f"\nTotal {len(…
+        sim = round(sum(sims) / len(sims), 3)
+        print(f"\nTotal {len(sims)} samples")
         print(f"SIM : {sim}")
 
 
src/f5_tts/eval/eval_utmos.py
CHANGED
@@ -1,46 +1,43 @@
-import …
-import librosa
-from pathlib import Path
+import argparse
 import json
+from pathlib import Path
+
+import librosa
+import torch
 from tqdm import tqdm
-import argparse
 
 
 def main():
-    parser = argparse.ArgumentParser(description="…
-    parser.add_argument(
-        …
-    )
-    parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
-    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
-    …
+    parser = argparse.ArgumentParser(description="UTMOS Evaluation")
+    parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
+    parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
     args = parser.parse_args()
 
-    device = "cuda" if …
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
     predictor = predictor.to(device)
 
-    …
-    …
-    …
+    audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+    utmos_results = {}
+    utmos_score = 0
 
-    for …
-    …
-    …
-    …
-        score = predictor(…
-    …
-    …
+    for audio_path in tqdm(audio_paths, desc="Processing"):
+        wav_name = audio_path.stem
+        wav, sr = librosa.load(audio_path, sr=None, mono=True)
+        wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+        score = predictor(wav_tensor, sr)
+        utmos_results[str(wav_name)] = score.item()
+        utmos_score += score.item()
 
-    avg_score = …
+    avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
     print(f"UTMOS: {avg_score}")
 
-    …
-    with open(…
-        json.dump(…
+    utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+    with open(utmos_result_path, "w", encoding="utf-8") as f:
+        json.dump(utmos_results, f, ensure_ascii=False, indent=4)
 
-    print(f"Results have been saved to {…
+    print(f"Results have been saved to {utmos_result_path}")
 
 
 if __name__ == "__main__":
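Besides printing the corpus-level average, the rewritten UTMOS script keeps a per-file score map in `utmos_results.json` inside `--audio_dir`. A small sketch for inspecting it afterwards, assuming that output format (the directory name is illustrative):

```python
import json
from pathlib import Path

audio_dir = Path("results/my_gen_wavs")  # illustrative; same folder passed as --audio_dir
with open(audio_dir / "utmos_results.json", "r", encoding="utf-8") as f:
    utmos_results = json.load(f)  # {wav_stem: utmos_score}

avg = sum(utmos_results.values()) / len(utmos_results)
print(f"UTMOS: {avg:.3f} over {len(utmos_results)} files")

# Ten lowest-scoring utterances
for name, score in sorted(utmos_results.items(), key=lambda kv: kv[1])[:10]:
    print(f"{score:.3f}\t{name}")
```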
src/f5_tts/eval/utils_eval.py
CHANGED
@@ -2,12 +2,13 @@ import math
 import os
 import random
 import string
+from pathlib import Path
 
 import torch
 import torch.nn.functional as F
 import torchaudio
 from tqdm import tqdm
-…
+
 from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
 from f5_tts.model.modules import MelSpec
 from f5_tts.model.utils import convert_char_to_pinyin
@@ -320,7 +321,7 @@ def run_asr_wer(args):
     from zhon.hanzi import punctuation
 
     punctuation_all = punctuation + string.punctuation
-    …
+    wer_results = []
 
     from jiwer import compute_measures
 
@@ -335,8 +336,8 @@ def run_asr_wer(args):
         for segment in segments:
             hypo = hypo + " " + segment.text
 
-        …
-        …
+        raw_truth = truth
+        raw_hypo = hypo
 
         for x in punctuation_all:
             truth = truth.replace(x, "")
@@ -360,16 +361,16 @@ def run_asr_wer(args):
         # dele = measures["deletions"] / len(ref_list)
         # inse = measures["insertions"] / len(ref_list)
 
-        …
+        wer_results.append(
             {
-                "wav": Path(gen_wav).stem,
-                "truth": …
-                "hypo": …
-                "wer": wer,
+                "wav": Path(gen_wav).stem,
+                "truth": raw_truth,
+                "hypo": raw_hypo,
+                "wer": wer,
             }
         )
 
-    return …
+    return wer_results
 
 
 # SIM Evaluation
@@ -388,7 +389,7 @@ def run_sim(args):
     model = model.cuda(device)
     model.eval()
 
-    …
+    sims = []
     for wav1, wav2, truth in tqdm(test_set):
         wav1, sr1 = torchaudio.load(wav1)
         wav2, sr2 = torchaudio.load(wav2)
@@ -407,6 +408,6 @@ def run_sim(args):
 
         sim = F.cosine_similarity(emb1, emb2)[0].item()
         # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
-        …
+        sims.append(sim)
 
-    return …
+    return sims
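These utils_eval.py changes are what enable the script changes above: run_asr_wer now returns a list of per-sample dicts and run_sim returns a list of similarity scores, rather than reporting as a side effect, so the callers flatten the per-GPU lists coming back from the pool. A minimal sketch of that aggregation pattern, with a hypothetical stand-in worker:

```python
import multiprocessing as mp


def run_metric(args):
    # Stand-in for run_asr_wer / run_sim: each worker gets one shard
    # and returns a list of per-sample results.
    rank, shard = args
    return [{"wav": name, "wer": 0.0} for name in shard]


if __name__ == "__main__":
    shards = [(0, ["utt_0001", "utt_0002"]), (1, ["utt_0003", "utt_0004"])]
    results = []
    with mp.Pool(processes=len(shards)) as pool:
        for r in pool.map(run_metric, shards):
            results.extend(r)  # flatten per-worker lists, as the eval scripts now do
    print(f"Collected {len(results)} per-sample results")
```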
src/f5_tts/infer/README.md
CHANGED
@@ -64,6 +64,9 @@ f5-tts_infer-cli \
 # Choose Vocoder
 f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
 f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+
+# More instructions
+f5-tts_infer-cli --help
 ```
 
 And a `.toml` file would help with more flexible usage.
src/f5_tts/infer/SHARED.md
CHANGED
@@ -22,12 +22,12 @@
 - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
 - [French](#french)
   - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
+- [Hindi](#hindi)
+  - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
 - [Italian](#italian)
   - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
 - [Japanese](#japanese)
   - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
-- [Hindi](#hindi)
-  - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
 - [Mandarin](#mandarin)
 - [Spanish](#spanish)
   - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
@@ -81,6 +81,23 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
 - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
 
 
+## Hindi
+
+#### F5-TTS Small @ pretrain @ hi
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
+
+```bash
+MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
+VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
+```
+
+Authors: SPRING Lab, Indian Institute of Technology, Madras
+<br>
+Website: https://asr.iitm.ac.in/
+
+
 ## Italian
 
 #### F5-TTS Italian @ finetune @ it
@@ -110,21 +127,6 @@ MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
 VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 ```
 
-## Hindi
-
-#### F5-TTS Small @ pretrain @ hi
-|Model|🤗Hugging Face|Data (Hours)|Model License|
-|:---:|:------------:|:-----------:|:-------------:|
-|F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
-
-```bash
-MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
-VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
-```
-
-Authors: SPRING Lab, Indian Institute of Technology, Madras
-<br>
-Website: https://asr.iitm.ac.in/
 
 ## Mandarin
 
src/f5_tts/infer/examples/basic/basic.toml
CHANGED
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
 gen_file = ""
 remove_silence = false
 output_dir = "tests"
-output_file = "…
+output_file = "infer_cli_basic.wav"
src/f5_tts/infer/examples/multi/story.toml
CHANGED
@@ -8,6 +8,7 @@ gen_text = ""
 gen_file = "infer/examples/multi/story.txt"
 remove_silence = true
 output_dir = "tests"
+output_file = "infer_cli_story.wav"
 
 [voices.town]
 ref_audio = "infer/examples/multi/town.flac"
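With `output_file` now set in both example configs, a run such as `f5-tts_infer-cli -c src/f5_tts/infer/examples/multi/story.toml` (from the repo root) should write the multi-voice result to `tests/infer_cli_story.wav`, and with `--save_chunk` the per-chunk files land in a sibling `tests/infer_cli_story_chunks/` folder, following the chunk-directory naming introduced in infer_cli.py below.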
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -2,6 +2,7 @@ import argparse
 import codecs
 import os
 import re
+from datetime import datetime
 from importlib.resources import files
 from pathlib import Path
 
@@ -11,6 +12,14 @@ import tomli
 from cached_path import cached_path
 
 from f5_tts.infer.utils_infer import (
+    mel_spec_type,
+    target_rms,
+    cross_fade_duration,
+    nfe_step,
+    cfg_strength,
+    sway_sampling_coef,
+    speed,
+    fix_duration,
     infer_process,
     load_model,
     load_vocoder,
@@ -19,6 +28,7 @@ from f5_tts.infer.utils_infer import (
 )
 from f5_tts.model import DiT, UNetT
 
+
 parser = argparse.ArgumentParser(
     prog="python3 infer-cli.py",
     description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,86 +37,161 @@ parser = argparse.ArgumentParser(
 parser.add_argument(
     "-c",
     "--config",
-    …
+    type=str,
     default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
+    help="The configuration file, default see infer/examples/basic/basic.toml",
 )
+
+
+# Note. Not to provide default value here in order to read default from config file
+
 parser.add_argument(
     "-m",
     "--model",
-    …
+    type=str,
+    help="The model name: F5-TTS | E2-TTS",
 )
 parser.add_argument(
     "-p",
     "--ckpt_file",
-    …
+    type=str,
+    help="The path to model checkpoint .pt, leave blank to use default",
 )
 parser.add_argument(
     "-v",
     "--vocab_file",
-    …
+    type=str,
+    help="The path to vocab file .txt, leave blank to use default",
+)
+parser.add_argument(
+    "-r",
+    "--ref_audio",
+    type=str,
+    help="The reference audio file.",
+)
+parser.add_argument(
+    "-s",
+    "--ref_text",
+    type=str,
+    help="The transcript/subtitle for the reference audio",
 )
-parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
-parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
 parser.add_argument(
     "-t",
     "--gen_text",
     type=str,
-    help="…
+    help="The text to make model synthesize a speech",
 )
 parser.add_argument(
     "-f",
     "--gen_file",
     type=str,
-    help="…
+    help="The file with text to generate, will ignore --gen_text",
 )
 parser.add_argument(
     "-o",
     "--output_dir",
     type=str,
-    help="…
+    help="The path to output folder",
 )
 parser.add_argument(
     "-w",
     "--output_file",
    type=str,
-    help="…
+    help="The name of output file",
 )
 parser.add_argument(
     "--save_chunk",
     action="store_true",
-    help="…
+    help="To save each audio chunks during inference",
 )
 parser.add_argument(
     "--remove_silence",
-    …
+    action="store_true",
+    help="To remove long silence found in ouput",
 )
-parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
 parser.add_argument(
     "--load_vocoder_from_local",
     action="store_true",
-    help="load vocoder from local …
+    help="To load vocoder from local dir, default to ../checkpoints/charactr/vocos-mel-24khz",
 )
 parser.add_argument(
-    "--…
+    "--vocoder_name",
+    type=str,
+    choices=["vocos", "bigvgan"],
+    help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
+)
+parser.add_argument(
+    "--target_rms",
+    type=float,
+    help=f"Target output speech loudness normalization value, default {target_rms}",
+)
+parser.add_argument(
+    "--cross_fade_duration",
     type=float,
-    …
-    help="Adjust the speed of the audio generation (default: 1.0)",
+    help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
 )
 parser.add_argument(
     "--nfe_step",
     type=int,
-    …
-    …
+    help=f"The number of function evaluation (denoising steps), default {nfe_step}",
+)
+parser.add_argument(
+    "--cfg_strength",
+    type=float,
+    help=f"Classifier-free guidance strength, default {cfg_strength}",
+)
+parser.add_argument(
+    "--sway_sampling_coef",
+    type=float,
+    help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
+)
+parser.add_argument(
+    "--speed",
+    type=float,
+    help=f"The speed of the generated audio, default {speed}",
+)
+parser.add_argument(
+    "--fix_duration",
+    type=float,
+    help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
 )
 args = parser.parse_args()
 
+
+# config file
+
 config = tomli.load(open(args.config, "rb"))
 
-…
-…
-…
-…
-…
+
+# command-line interface parameters
+
+model = args.model or config.get("model", "F5-TTS")
+ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
+vocab_file = args.vocab_file or config.get("vocab_file", "")
+
+ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
+ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
+gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
+gen_file = args.gen_file or config.get("gen_file", "")
+
+output_dir = args.output_dir or config.get("output_dir", "tests")
+output_file = args.output_file or config.get(
+    "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
+)
+
+save_chunk = args.save_chunk
+remove_silence = args.remove_silence
+load_vocoder_from_local = args.load_vocoder_from_local
+
+vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
+target_rms = args.target_rms or config.get("target_rms", target_rms)
+cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
+nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
+cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
+sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
+speed = args.speed or config.get("speed", speed)
+fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
+
 
 # patches for pip pkg user
 if "infer/examples/" in ref_audio:
@@ -119,35 +204,39 @@ if "voices" in config:
         if "infer/examples/" in voice_ref_audio:
             config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
 
+
+# ignore gen_text if gen_file provided
+
 if gen_file:
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
-…
-…
-…
-ckpt_file = args.ckpt_file if args.ckpt_file else ""
-vocab_file = args.vocab_file if args.vocab_file else ""
-remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
-speed = args.speed
-nfe_step = args.nfe_step
+
+
+# output path
 
 wave_path = Path(output_dir) / output_file
 # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
+if save_chunk:
+    output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
+    if not os.path.exists(output_chunk_dir):
+        os.makedirs(output_chunk_dir)
+
+
+# load vocoder
 
-vocoder_name = args.vocoder_name
-mel_spec_type = args.vocoder_name
 if vocoder_name == "vocos":
     vocoder_local_path = "../checkpoints/vocos-mel-24khz"
 elif vocoder_name == "bigvgan":
     vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
-vocoder = load_vocoder(vocoder_name=…
+vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
+
 
-# load models
+# load TTS model
 
 if model == "F5-TTS":
     model_cls = DiT
     model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
-    if ckpt_file …
+    if not ckpt_file:  # path not specified, download from repo
     if vocoder_name == "vocos":
         repo_name = "F5-TTS"
         exp_name = "F5TTS_Base"
@@ -164,19 +253,21 @@ elif model == "E2-TTS":
     assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
     model_cls = UNetT
     model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
-    if ckpt_file …
+    if not ckpt_file:  # path not specified, download from repo
         repo_name = "E2-TTS"
         exp_name = "E2TTS_Base"
         ckpt_step = 1200000
         ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
         # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
-…
 print(f"Using {model}...")
-ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=…
+ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
+
 
+# inference process
 
-…
+
+def main():
     main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -184,16 +275,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
         voices = config["voices"]
         voices["main"] = main_voice
     for voice in voices:
+        print("Voice:", voice)
+        print("ref_audio ", voices[voice]["ref_audio"])
         voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
             voices[voice]["ref_audio"], voices[voice]["ref_text"]
         )
-        print("…
-        print("Ref_audio:", voices[voice]["ref_audio"])
-        print("Ref_text:", voices[voice]["ref_text"])
+        print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
 
     generated_audio_segments = []
     reg1 = r"(?=\[\w+\])"
-    chunks = re.split(reg1, …
+    chunks = re.split(reg1, gen_text)
     reg2 = r"\[(\w+)\]"
     for text in chunks:
         if not text.strip():
@@ -208,21 +299,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
             print(f"Voice {voice} not found, using main.")
             voice = "main"
         text = re.sub(reg2, "", text)
-        …
-        …
-        …
+        ref_audio_ = voices[voice]["ref_audio"]
+        ref_text_ = voices[voice]["ref_text"]
+        gen_text_ = text.strip()
         print(f"Voice: {voice}")
-        …
-        …
-        …
-        …
-        …
+        audio_segment, final_sample_rate, spectragram = infer_process(
+            ref_audio_,
+            ref_text_,
+            gen_text_,
+            ema_model,
             vocoder,
-            mel_spec_type=…
-            …
+            mel_spec_type=vocoder_name,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
             nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
         )
-        generated_audio_segments.append(…
+        generated_audio_segments.append(audio_segment)
+
+        if save_chunk:
+            if len(gen_text_) > 200:
+                gen_text_ = gen_text_[:200] + " ... "
+            sf.write(
+                os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
+                audio_segment,
+                final_sample_rate,
+            )
 
     if generated_audio_segments:
         final_wave = np.concatenate(generated_audio_segments)
@@ -236,22 +341,6 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
             if remove_silence:
                 remove_silence_for_generated_wav(f.name)
             print(f.name)
-        # Ensure the gen_text chunk directory exists
-        …
-        if save_chunk:
-            gen_text_chunk_dir = os.path.join(output_dir, "chunks")
-            if not os.path.exists(gen_text_chunk_dir):  # if Not create directory
-                os.makedirs(gen_text_chunk_dir)
-            …
-            # Save individual chunks as separate files
-            for idx, segment in enumerate(generated_audio_segments):
-                gen_text_chunk_path = os.path.join(output_dir, gen_text_chunk_dir, f"chunk_{idx}.wav")
-                sf.write(gen_text_chunk_path, segment, final_sample_rate)
-                print(f"Saved gen_text chunk {idx} at {gen_text_chunk_path}")
-        …
-        …
-def main():
-    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
 
 
 if __name__ == "__main__":
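Taken together, the rewritten infer_cli.py resolves every inference setting in three steps: an explicit command-line flag wins, otherwise the value from the `-c/--config` TOML is used, and otherwise the default imported from `f5_tts.infer.utils_infer` (mel_spec_type, target_rms, cross_fade_duration, nfe_step, cfg_strength, sway_sampling_coef, speed, fix_duration) applies. So a call like `f5-tts_infer-cli -c my_config.toml --nfe_step 64` (config name illustrative) keeps everything from the TOML except the denoising-step count.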