SWivid committed
Commit b7bc641 · 1 Parent(s): 3c60f99

reorganize infer_cli and stuff

README.md CHANGED
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tens
147
  ## Acknowledgements
148
 
149
  - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
150
- - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
151
  - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
152
  - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
153
  - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
154
- - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
155
  - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
156
  - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
157
  - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
 
147
  ## Acknowledgements
148
 
149
  - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
150
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
151
  - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
152
  - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
153
  - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
154
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
155
  - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
156
  - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
157
  - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
src/f5_tts/eval/README.md CHANGED
@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckp
39
 
40
  ### Objective Evaluation
41
 
42
- Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
43
  ```bash
44
- # Evaluation for Seed-TTS test set
45
- python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
46
 
47
- # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
48
- python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
49
- ```
 
 
 
 
39
 
40
  ### Objective Evaluation
41
 
42
+ Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
43
  ```bash
44
+ # Evaluation [WER] for Seed-TTS test [ZH] set
45
+ python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
46
 
47
+ # Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
48
+ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
49
+
50
+ # Evaluation [UTMOS]. --ext: Audio extension
51
+ python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
52
+ ```
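Both test-set scripts branch on `--eval_task` (`wer` or `sim`), so the remaining combinations follow the same pattern; for instance, a SIM run on the Seed-TTS test [ZH] set would look like the sketch below, reusing the flags from the WER command above.
```bash
# Sketch: SIM evaluation on the Seed-TTS test [ZH] set (flags mirror the WER command above)
python src/f5_tts/eval/eval_seedtts_testset.py --eval_task sim --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8
```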
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -1,8 +1,9 @@
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
- import sys
4
- import os
5
  import argparse
 
 
 
6
 
7
  sys.path.append(os.getcwd())
8
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
10
  from importlib.resources import files
11
 
12
  import numpy as np
13
- import json
14
  from f5_tts.eval.utils_eval import (
15
  get_librispeech_test,
16
  run_asr_wer,
@@ -54,36 +54,41 @@ def main():
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
  # --------------------------- WER ---------------------------
 
57
  if eval_task == "wer":
58
- wers = []
59
  wer_results = []
 
 
60
  with mp.Pool(processes=len(gpus)) as pool:
61
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
62
  results = pool.map(run_asr_wer, args)
63
- for wers_ in results:
64
- wers.extend(wers_)
65
 
66
- with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
67
- for line in wers:
68
- wer_results.append(line["wer"])
 
69
  json_line = json.dumps(line, ensure_ascii=False)
70
  f.write(json_line + "\n")
71
 
72
  wer = round(np.mean(wers) * 100, 3)
73
  print(f"\nTotal {len(wers)} samples")
74
  print(f"WER : {wer}%")
 
75
 
76
  # --------------------------- SIM ---------------------------
 
77
  if eval_task == "sim":
78
- sim_list = []
79
  with mp.Pool(processes=len(gpus)) as pool:
80
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
81
  results = pool.map(run_sim, args)
82
- for sim_ in results:
83
- sim_list.extend(sim_)
84
 
85
- sim = round(sum(sim_list) / len(sim_list), 3)
86
- print(f"\nTotal {len(sim_list)} samples")
87
  print(f"SIM : {sim}")
88
 
89
 
 
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
 
 
3
  import argparse
4
+ import json
5
+ import os
6
+ import sys
7
 
8
  sys.path.append(os.getcwd())
9
 
 
11
  from importlib.resources import files
12
 
13
  import numpy as np
 
14
  from f5_tts.eval.utils_eval import (
15
  get_librispeech_test,
16
  run_asr_wer,
 
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
  # --------------------------- WER ---------------------------
57
+
58
  if eval_task == "wer":
 
59
  wer_results = []
60
+ wers = []
61
+
62
  with mp.Pool(processes=len(gpus)) as pool:
63
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
64
  results = pool.map(run_asr_wer, args)
65
+ for r in results:
66
+ wer_results.extend(r)
67
 
68
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
69
+ with open(wer_result_path, "w") as f:
70
+ for line in wer_results:
71
+ wers.append(line["wer"])
72
  json_line = json.dumps(line, ensure_ascii=False)
73
  f.write(json_line + "\n")
74
 
75
  wer = round(np.mean(wers) * 100, 3)
76
  print(f"\nTotal {len(wers)} samples")
77
  print(f"WER : {wer}%")
78
+ print(f"Results have been saved to {wer_result_path}")
79
 
80
  # --------------------------- SIM ---------------------------
81
+
82
  if eval_task == "sim":
83
+ sims = []
84
  with mp.Pool(processes=len(gpus)) as pool:
85
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
86
  results = pool.map(run_sim, args)
87
+ for r in results:
88
+ sims.extend(r)
89
 
90
+ sim = round(sum(sims) / len(sims), 3)
91
+ print(f"\nTotal {len(sims)} samples")
92
  print(f"SIM : {sim}")
93
 
94
 
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -1,8 +1,9 @@
1
  # Evaluate with Seed-TTS testset
2
 
3
- import sys
4
- import os
5
  import argparse
 
 
 
6
 
7
  sys.path.append(os.getcwd())
8
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
10
  from importlib.resources import files
11
 
12
  import numpy as np
13
- import json
14
  from f5_tts.eval.utils_eval import (
15
  get_seed_tts_test,
16
  run_asr_wer,
@@ -55,35 +55,39 @@ def main():
55
  # --------------------------- WER ---------------------------
56
 
57
  if eval_task == "wer":
58
- wers = []
59
  wer_results = []
 
 
60
  with mp.Pool(processes=len(gpus)) as pool:
61
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
62
  results = pool.map(run_asr_wer, args)
63
- for wers_ in results:
64
- wers.extend(wers_)
65
 
66
- with open(f"{gen_wav_dir}/{lang}_wer_results.jsonl", "w") as f:
67
- for line in wers:
68
- wer_results.append(line["wer"])
 
69
  json_line = json.dumps(line, ensure_ascii=False)
70
  f.write(json_line + "\n")
71
 
72
  wer = round(np.mean(wers) * 100, 3)
73
  print(f"\nTotal {len(wers)} samples")
74
  print(f"WER : {wer}%")
 
75
 
76
  # --------------------------- SIM ---------------------------
 
77
  if eval_task == "sim":
78
- sim_list = []
79
  with mp.Pool(processes=len(gpus)) as pool:
80
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
81
  results = pool.map(run_sim, args)
82
- for sim_ in results:
83
- sim_list.extend(sim_)
84
 
85
- sim = round(sum(sim_list) / len(sim_list), 3)
86
- print(f"\nTotal {len(sim_list)} samples")
87
  print(f"SIM : {sim}")
88
 
89
 
 
1
  # Evaluate with Seed-TTS testset
2
 
 
 
3
  import argparse
4
+ import json
5
+ import os
6
+ import sys
7
 
8
  sys.path.append(os.getcwd())
9
 
 
11
  from importlib.resources import files
12
 
13
  import numpy as np
 
14
  from f5_tts.eval.utils_eval import (
15
  get_seed_tts_test,
16
  run_asr_wer,
 
55
  # --------------------------- WER ---------------------------
56
 
57
  if eval_task == "wer":
 
58
  wer_results = []
59
+ wers = []
60
+
61
  with mp.Pool(processes=len(gpus)) as pool:
62
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
63
  results = pool.map(run_asr_wer, args)
64
+ for r in results:
65
+ wer_results.extend(r)
66
 
67
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
68
+ with open(wer_result_path, "w") as f:
69
+ for line in wer_results:
70
+ wers.append(line["wer"])
71
  json_line = json.dumps(line, ensure_ascii=False)
72
  f.write(json_line + "\n")
73
 
74
  wer = round(np.mean(wers) * 100, 3)
75
  print(f"\nTotal {len(wers)} samples")
76
  print(f"WER : {wer}%")
77
+ print(f"Results have been saved to {wer_result_path}")
78
 
79
  # --------------------------- SIM ---------------------------
80
+
81
  if eval_task == "sim":
82
+ sims = []
83
  with mp.Pool(processes=len(gpus)) as pool:
84
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
85
  results = pool.map(run_sim, args)
86
+ for r in results:
87
+ sims.extend(r)
88
 
89
+ sim = round(sum(sims) / len(sims), 3)
90
+ print(f"\nTotal {len(sims)} samples")
91
  print(f"SIM : {sim}")
92
 
93
 
src/f5_tts/eval/eval_utmos.py CHANGED
@@ -1,46 +1,43 @@
1
- import torch
2
- import librosa
3
- from pathlib import Path
4
  import json
 
 
 
 
5
  from tqdm import tqdm
6
- import argparse
7
 
8
 
9
  def main():
10
- parser = argparse.ArgumentParser(description="Evaluate UTMOS scores for audio files.")
11
- parser.add_argument(
12
- "--audio_dir", type=str, required=True, help="Path to the directory containing WAV audio files."
13
- )
14
- parser.add_argument("--ext", type=str, default="wav", help="audio extension.")
15
- parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on (e.g. 'cuda' or 'cpu').")
16
-
17
  args = parser.parse_args()
18
 
19
- device = "cuda" if args.device and torch.cuda.is_available() else "cpu"
20
 
21
  predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
22
  predictor = predictor.to(device)
23
 
24
- lines = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
25
- results = {}
26
- utmos_result = 0
27
 
28
- for line in tqdm(lines, desc="Processing"):
29
- wave_name = line.stem
30
- wave, sr = librosa.load(line, sr=None, mono=True)
31
- wave_tensor = torch.from_numpy(wave).to(device).unsqueeze(0)
32
- score = predictor(wave_tensor, sr)
33
- results[str(wave_name)] = score.item()
34
- utmos_result += score.item()
35
 
36
- avg_score = utmos_result / len(lines) if len(lines) > 0 else 0
37
  print(f"UTMOS: {avg_score}")
38
 
39
- output_path = Path(args.audio_dir) / "utmos_results.json"
40
- with open(output_path, "w", encoding="utf-8") as f:
41
- json.dump(results, f, ensure_ascii=False, indent=4)
42
 
43
- print(f"Results have been saved to {output_path}")
44
 
45
 
46
  if __name__ == "__main__":
 
1
+ import argparse
 
 
2
  import json
3
+ from pathlib import Path
4
+
5
+ import librosa
6
+ import torch
7
  from tqdm import tqdm
 
8
 
9
 
10
  def main():
11
+ parser = argparse.ArgumentParser(description="UTMOS Evaluation")
12
+ parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
13
+ parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
 
 
 
 
14
  args = parser.parse_args()
15
 
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
19
  predictor = predictor.to(device)
20
 
21
+ audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
22
+ utmos_results = {}
23
+ utmos_score = 0
24
 
25
+ for audio_path in tqdm(audio_paths, desc="Processing"):
26
+ wav_name = audio_path.stem
27
+ wav, sr = librosa.load(audio_path, sr=None, mono=True)
28
+ wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
29
+ score = predictor(wav_tensor, sr)
30
+ utmos_results[str(wav_name)] = score.item()
31
+ utmos_score += score.item()
32
 
33
+ avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
34
  print(f"UTMOS: {avg_score}")
35
 
36
+ utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
37
+ with open(utmos_result_path, "w", encoding="utf-8") as f:
38
+ json.dump(utmos_results, f, ensure_ascii=False, indent=4)
39
 
40
+ print(f"Results have been saved to {utmos_result_path}")
41
 
42
 
43
  if __name__ == "__main__":
src/f5_tts/eval/utils_eval.py CHANGED
@@ -2,12 +2,13 @@ import math
2
  import os
3
  import random
4
  import string
 
5
 
6
  import torch
7
  import torch.nn.functional as F
8
  import torchaudio
9
  from tqdm import tqdm
10
- from pathlib import Path
11
  from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
12
  from f5_tts.model.modules import MelSpec
13
  from f5_tts.model.utils import convert_char_to_pinyin
@@ -320,7 +321,7 @@ def run_asr_wer(args):
320
  from zhon.hanzi import punctuation
321
 
322
  punctuation_all = punctuation + string.punctuation
323
- wers = []
324
 
325
  from jiwer import compute_measures
326
 
@@ -335,8 +336,8 @@ def run_asr_wer(args):
335
  for segment in segments:
336
  hypo = hypo + " " + segment.text
337
 
338
- # raw_truth = truth
339
- # raw_hypo = hypo
340
 
341
  for x in punctuation_all:
342
  truth = truth.replace(x, "")
@@ -360,16 +361,16 @@ def run_asr_wer(args):
360
  # dele = measures["deletions"] / len(ref_list)
361
  # inse = measures["insertions"] / len(ref_list)
362
 
363
- wers.append(
364
  {
365
- "wav": Path(gen_wav).stem, # wav name
366
- "truth": truth, # raw_truth
367
- "hypo": hypo, # raw_hypo
368
- "wer": wer, # wer score
369
  }
370
  )
371
 
372
- return wers
373
 
374
 
375
  # SIM Evaluation
@@ -388,7 +389,7 @@ def run_sim(args):
388
  model = model.cuda(device)
389
  model.eval()
390
 
391
- sim_list = []
392
  for wav1, wav2, truth in tqdm(test_set):
393
  wav1, sr1 = torchaudio.load(wav1)
394
  wav2, sr2 = torchaudio.load(wav2)
@@ -407,6 +408,6 @@ def run_sim(args):
407
 
408
  sim = F.cosine_similarity(emb1, emb2)[0].item()
409
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
410
- sim_list.append(sim)
411
 
412
- return sim_list
 
2
  import os
3
  import random
4
  import string
5
+ from pathlib import Path
6
 
7
  import torch
8
  import torch.nn.functional as F
9
  import torchaudio
10
  from tqdm import tqdm
11
+
12
  from f5_tts.eval.ecapa_tdnn import ECAPA_TDNN_SMALL
13
  from f5_tts.model.modules import MelSpec
14
  from f5_tts.model.utils import convert_char_to_pinyin
 
321
  from zhon.hanzi import punctuation
322
 
323
  punctuation_all = punctuation + string.punctuation
324
+ wer_results = []
325
 
326
  from jiwer import compute_measures
327
 
 
336
  for segment in segments:
337
  hypo = hypo + " " + segment.text
338
 
339
+ raw_truth = truth
340
+ raw_hypo = hypo
341
 
342
  for x in punctuation_all:
343
  truth = truth.replace(x, "")
 
361
  # dele = measures["deletions"] / len(ref_list)
362
  # inse = measures["insertions"] / len(ref_list)
363
 
364
+ wer_results.append(
365
  {
366
+ "wav": Path(gen_wav).stem,
367
+ "truth": raw_truth,
368
+ "hypo": raw_hypo,
369
+ "wer": wer,
370
  }
371
  )
372
 
373
+ return wer_results
374
 
375
 
376
  # SIM Evaluation
 
389
  model = model.cuda(device)
390
  model.eval()
391
 
392
+ sims = []
393
  for wav1, wav2, truth in tqdm(test_set):
394
  wav1, sr1 = torchaudio.load(wav1)
395
  wav2, sr2 = torchaudio.load(wav2)
 
408
 
409
  sim = F.cosine_similarity(emb1, emb2)[0].item()
410
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
411
+ sims.append(sim)
412
 
413
+ return sims
src/f5_tts/infer/README.md CHANGED
@@ -64,6 +64,9 @@ f5-tts_infer-cli \
64
  # Choose Vocoder
65
  f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
66
  f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
 
 
 
67
  ```
68
 
69
  And a `.toml` file would help with more flexible usage.
 
64
  # Choose Vocoder
65
  f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
66
  f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
67
+
68
+ # More instructions
69
+ f5-tts_infer-cli --help
70
  ```
71
 
72
  And a `.toml` file would help with more flexible usage.
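For instance, a minimal config could look like the sketch below; the keys and values mirror `infer/examples/basic/basic.toml` and the defaults read by `infer_cli.py`, and `my_config.toml` is only a placeholder name.
```bash
# Sketch of a custom config; keys mirror infer/examples/basic/basic.toml, values are illustrative
cat > my_config.toml << 'EOF'
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
ref_text = "Some call me nature, others call me mother nature."
gen_text = "Here we generate something just for test."
gen_file = ""
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"
EOF

# Point the CLI at the custom config
f5-tts_infer-cli -c my_config.toml
```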
src/f5_tts/infer/SHARED.md CHANGED
@@ -22,12 +22,12 @@
22
  - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
23
  - [French](#french)
24
  - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
 
 
25
  - [Italian](#italian)
26
  - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
27
  - [Japanese](#japanese)
28
  - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
29
- - [Hindi](#hindi)
30
- - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
31
  - [Mandarin](#mandarin)
32
  - [Spanish](#spanish)
33
  - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
@@ -81,6 +81,23 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
81
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
82
 
83
 
 
84
  ## Italian
85
 
86
  #### F5-TTS Italian @ finetune @ it
@@ -110,21 +127,6 @@ MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
110
  VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
111
  ```
112
 
113
- ## Hindi
114
-
115
- #### F5-TTS Small @ pretrain @ hi
116
- |Model|🤗Hugging Face|Data (Hours)|Model License|
117
- |:---:|:------------:|:-----------:|:-------------:|
118
- |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
119
-
120
- ```bash
121
- MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
122
- VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
123
- ```
124
-
125
- Authors: SPRING Lab, Indian Institute of Technology, Madras
126
- <br>
127
- Website: https://asr.iitm.ac.in/
128
 
129
  ## Mandarin
130
 
 
22
  - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
23
  - [French](#french)
24
  - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
25
+ - [Hindi](#hindi)
26
+ - [F5-TTS Small @ pretrain @ hi](#f5-tts-small--pretrain--hi)
27
  - [Italian](#italian)
28
  - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
29
  - [Japanese](#japanese)
30
  - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
 
 
31
  - [Mandarin](#mandarin)
32
  - [Spanish](#spanish)
33
  - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
 
81
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
82
 
83
 
84
+ ## Hindi
85
+
86
+ #### F5-TTS Small @ pretrain @ hi
87
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
88
+ |:---:|:------------:|:-----------:|:-------------:|
89
+ |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
90
+
91
+ ```bash
92
+ MODEL_CKPT: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
93
+ VOCAB_FILE: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
94
+ ```
95
+
96
+ Authors: SPRING Lab, Indian Institute of Technology, Madras
97
+ <br>
98
+ Website: https://asr.iitm.ac.in/
99
+
100
+
101
  ## Italian
102
 
103
  #### F5-TTS Italian @ finetune @ it
 
127
  VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
128
  ```
129
 
 
 
130
 
131
  ## Mandarin
132
 
src/f5_tts/infer/examples/basic/basic.toml CHANGED
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
- output_file = "infer_cli_out.wav"
 
8
  gen_file = ""
9
  remove_silence = false
10
  output_dir = "tests"
11
+ output_file = "infer_cli_basic.wav"
src/f5_tts/infer/examples/multi/story.toml CHANGED
@@ -8,6 +8,7 @@ gen_text = ""
8
  gen_file = "infer/examples/multi/story.txt"
9
  remove_silence = true
10
  output_dir = "tests"
 
11
 
12
  [voices.town]
13
  ref_audio = "infer/examples/multi/town.flac"
 
8
  gen_file = "infer/examples/multi/story.txt"
9
  remove_silence = true
10
  output_dir = "tests"
11
+ output_file = "infer_cli_story.wav"
12
 
13
  [voices.town]
14
  ref_audio = "infer/examples/multi/town.flac"
src/f5_tts/infer/infer_cli.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import codecs
3
  import os
4
  import re
 
5
  from importlib.resources import files
6
  from pathlib import Path
7
 
@@ -11,6 +12,14 @@ import tomli
11
  from cached_path import cached_path
12
 
13
  from f5_tts.infer.utils_infer import (
 
 
 
 
 
 
 
 
14
  infer_process,
15
  load_model,
16
  load_vocoder,
@@ -19,6 +28,7 @@ from f5_tts.infer.utils_infer import (
19
  )
20
  from f5_tts.model import DiT, UNetT
21
 
 
22
  parser = argparse.ArgumentParser(
23
  prog="python3 infer-cli.py",
24
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,86 +37,161 @@ parser = argparse.ArgumentParser(
27
  parser.add_argument(
28
  "-c",
29
  "--config",
30
- help="Configuration file. Default=infer/examples/basic/basic.toml",
31
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
 
32
  )
 
 
 
 
33
  parser.add_argument(
34
  "-m",
35
  "--model",
36
- help="F5-TTS | E2-TTS",
 
37
  )
38
  parser.add_argument(
39
  "-p",
40
  "--ckpt_file",
41
- help="The Checkpoint .pt",
 
42
  )
43
  parser.add_argument(
44
  "-v",
45
  "--vocab_file",
46
- help="The vocab .txt",
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
49
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
50
  parser.add_argument(
51
  "-t",
52
  "--gen_text",
53
  type=str,
54
- help="Text to generate.",
55
  )
56
  parser.add_argument(
57
  "-f",
58
  "--gen_file",
59
  type=str,
60
- help="File with text to generate. Ignores --gen_text",
61
  )
62
  parser.add_argument(
63
  "-o",
64
  "--output_dir",
65
  type=str,
66
- help="Path to output folder..",
67
  )
68
  parser.add_argument(
69
  "-w",
70
  "--output_file",
71
  type=str,
72
- help="Filename of output file..",
73
  )
74
  parser.add_argument(
75
  "--save_chunk",
76
  action="store_true",
77
- help="Save chunk audio if your text is too long.",
78
  )
79
  parser.add_argument(
80
  "--remove_silence",
81
- help="Remove silence.",
 
82
  )
83
- parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
84
  parser.add_argument(
85
  "--load_vocoder_from_local",
86
  action="store_true",
87
- help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
88
  )
89
  parser.add_argument(
90
- "--speed",
 
 
 
 
 
 
 
 
 
 
 
91
  type=float,
92
- default=1.0,
93
- help="Adjust the speed of the audio generation (default: 1.0)",
94
  )
95
  parser.add_argument(
96
  "--nfe_step",
97
  type=int,
98
- default=32,
99
- help="Set the number of denoising steps (default: 32)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  )
101
  args = parser.parse_args()
102
 
 
 
 
103
  config = tomli.load(open(args.config, "rb"))
104
 
105
- ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
106
- ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
107
- gen_text = args.gen_text if args.gen_text else config["gen_text"]
108
- gen_file = args.gen_file if args.gen_file else config["gen_file"]
109
- save_chunk = args.save_chunk if args.save_chunk else False
 
 
 
 
110
 
111
  # patches for pip pkg user
112
  if "infer/examples/" in ref_audio:
@@ -119,35 +204,39 @@ if "voices" in config:
119
  if "infer/examples/" in voice_ref_audio:
120
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
121
 
 
 
 
122
  if gen_file:
123
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
124
- output_dir = args.output_dir if args.output_dir else config["output_dir"]
125
- output_file = args.output_file if args.output_file else config["output_file"]
126
- model = args.model if args.model else config["model"]
127
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
128
- vocab_file = args.vocab_file if args.vocab_file else ""
129
- remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
130
- speed = args.speed
131
- nfe_step = args.nfe_step
132
 
133
  wave_path = Path(output_dir) / output_file
134
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 
 
 
 
 
 
 
135
 
136
- vocoder_name = args.vocoder_name
137
- mel_spec_type = args.vocoder_name
138
  if vocoder_name == "vocos":
139
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
140
  elif vocoder_name == "bigvgan":
141
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
142
 
143
- vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
144
 
 
145
 
146
- # load models
147
  if model == "F5-TTS":
148
  model_cls = DiT
149
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
150
- if ckpt_file == "":
151
  if vocoder_name == "vocos":
152
  repo_name = "F5-TTS"
153
  exp_name = "F5TTS_Base"
@@ -164,19 +253,21 @@ elif model == "E2-TTS":
164
  assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
165
  model_cls = UNetT
166
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
167
- if ckpt_file == "":
168
  repo_name = "E2-TTS"
169
  exp_name = "E2TTS_Base"
170
  ckpt_step = 1200000
171
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
172
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
173
 
174
-
175
  print(f"Using {model}...")
176
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
 
177
 
 
178
 
179
- def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
 
180
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
181
  if "voices" not in config:
182
  voices = {"main": main_voice}
@@ -184,16 +275,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
184
  voices = config["voices"]
185
  voices["main"] = main_voice
186
  for voice in voices:
 
 
187
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
188
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
189
  )
190
- print("Voice:", voice)
191
- print("Ref_audio:", voices[voice]["ref_audio"])
192
- print("Ref_text:", voices[voice]["ref_text"])
193
 
194
  generated_audio_segments = []
195
  reg1 = r"(?=\[\w+\])"
196
- chunks = re.split(reg1, text_gen)
197
  reg2 = r"\[(\w+)\]"
198
  for text in chunks:
199
  if not text.strip():
@@ -208,21 +299,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
208
  print(f"Voice {voice} not found, using main.")
209
  voice = "main"
210
  text = re.sub(reg2, "", text)
211
- gen_text = text.strip()
212
- ref_audio = voices[voice]["ref_audio"]
213
- ref_text = voices[voice]["ref_text"]
214
  print(f"Voice: {voice}")
215
- audio, final_sample_rate, spectragram = infer_process(
216
- ref_audio,
217
- ref_text,
218
- gen_text,
219
- model_obj,
220
  vocoder,
221
- mel_spec_type=mel_spec_type,
222
- speed=speed,
 
223
  nfe_step=nfe_step,
 
 
 
 
224
  )
225
- generated_audio_segments.append(audio)
 
 
 
 
 
 
 
 
 
226
 
227
  if generated_audio_segments:
228
  final_wave = np.concatenate(generated_audio_segments)
@@ -236,22 +341,6 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
236
  if remove_silence:
237
  remove_silence_for_generated_wav(f.name)
238
  print(f.name)
239
- # Ensure the gen_text chunk directory exists
240
-
241
- if save_chunk:
242
- gen_text_chunk_dir = os.path.join(output_dir, "chunks")
243
- if not os.path.exists(gen_text_chunk_dir): # if Not create directory
244
- os.makedirs(gen_text_chunk_dir)
245
-
246
- # Save individual chunks as separate files
247
- for idx, segment in enumerate(generated_audio_segments):
248
- gen_text_chunk_path = os.path.join(output_dir, gen_text_chunk_dir, f"chunk_{idx}.wav")
249
- sf.write(gen_text_chunk_path, segment, final_sample_rate)
250
- print(f"Saved gen_text chunk {idx} at {gen_text_chunk_path}")
251
-
252
-
253
- def main():
254
- main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
255
 
256
 
257
  if __name__ == "__main__":
 
2
  import codecs
3
  import os
4
  import re
5
+ from datetime import datetime
6
  from importlib.resources import files
7
  from pathlib import Path
8
 
 
12
  from cached_path import cached_path
13
 
14
  from f5_tts.infer.utils_infer import (
15
+ mel_spec_type,
16
+ target_rms,
17
+ cross_fade_duration,
18
+ nfe_step,
19
+ cfg_strength,
20
+ sway_sampling_coef,
21
+ speed,
22
+ fix_duration,
23
  infer_process,
24
  load_model,
25
  load_vocoder,
 
28
  )
29
  from f5_tts.model import DiT, UNetT
30
 
31
+
32
  parser = argparse.ArgumentParser(
33
  prog="python3 infer-cli.py",
34
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
 
37
  parser.add_argument(
38
  "-c",
39
  "--config",
40
+ type=str,
41
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
42
+ help="The configuration file, default: infer/examples/basic/basic.toml",
43
  )
44
+
45
+
46
+ # Note: do not set defaults here, so that unspecified options fall back to the config file
47
+
48
  parser.add_argument(
49
  "-m",
50
  "--model",
51
+ type=str,
52
+ help="The model name: F5-TTS | E2-TTS",
53
  )
54
  parser.add_argument(
55
  "-p",
56
  "--ckpt_file",
57
+ type=str,
58
+ help="The path to model checkpoint .pt, leave blank to use default",
59
  )
60
  parser.add_argument(
61
  "-v",
62
  "--vocab_file",
63
+ type=str,
64
+ help="The path to vocab file .txt, leave blank to use default",
65
+ )
66
+ parser.add_argument(
67
+ "-r",
68
+ "--ref_audio",
69
+ type=str,
70
+ help="The reference audio file.",
71
+ )
72
+ parser.add_argument(
73
+ "-s",
74
+ "--ref_text",
75
+ type=str,
76
+ help="The transcript/subtitle for the reference audio",
77
  )
 
 
78
  parser.add_argument(
79
  "-t",
80
  "--gen_text",
81
  type=str,
82
+ help="The text for the model to synthesize",
83
  )
84
  parser.add_argument(
85
  "-f",
86
  "--gen_file",
87
  type=str,
88
+ help="The file with text to generate, will ignore --gen_text",
89
  )
90
  parser.add_argument(
91
  "-o",
92
  "--output_dir",
93
  type=str,
94
+ help="The path to output folder",
95
  )
96
  parser.add_argument(
97
  "-w",
98
  "--output_file",
99
  type=str,
100
+ help="The name of output file",
101
  )
102
  parser.add_argument(
103
  "--save_chunk",
104
  action="store_true",
105
+ help="To save each audio chunk during inference",
106
  )
107
  parser.add_argument(
108
  "--remove_silence",
109
+ action="store_true",
110
+ help="To remove long silence found in output",
111
  )
 
112
  parser.add_argument(
113
  "--load_vocoder_from_local",
114
  action="store_true",
115
+ help="To load vocoder from local dir, defaults to ../checkpoints/charactr/vocos-mel-24khz",
116
  )
117
  parser.add_argument(
118
+ "--vocoder_name",
119
+ type=str,
120
+ choices=["vocos", "bigvgan"],
121
+ help=f"Vocoder to use: vocos | bigvgan, default {mel_spec_type}",
122
+ )
123
+ parser.add_argument(
124
+ "--target_rms",
125
+ type=float,
126
+ help=f"Target output speech loudness normalization value, default {target_rms}",
127
+ )
128
+ parser.add_argument(
129
+ "--cross_fade_duration",
130
  type=float,
131
+ help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
 
132
  )
133
  parser.add_argument(
134
  "--nfe_step",
135
  type=int,
136
+ help=f"The number of function evaluations (denoising steps), default {nfe_step}",
137
+ )
138
+ parser.add_argument(
139
+ "--cfg_strength",
140
+ type=float,
141
+ help=f"Classifier-free guidance strength, default {cfg_strength}",
142
+ )
143
+ parser.add_argument(
144
+ "--sway_sampling_coef",
145
+ type=float,
146
+ help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
147
+ )
148
+ parser.add_argument(
149
+ "--speed",
150
+ type=float,
151
+ help=f"The speed of the generated audio, default {speed}",
152
+ )
153
+ parser.add_argument(
154
+ "--fix_duration",
155
+ type=float,
156
+ help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
157
  )
158
  args = parser.parse_args()
159
 
160
+
161
+ # config file
162
+
163
  config = tomli.load(open(args.config, "rb"))
164
 
165
+
166
+ # command-line interface parameters
167
+
168
+ model = args.model or config.get("model", "F5-TTS")
169
+ ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
170
+ vocab_file = args.vocab_file or config.get("vocab_file", "")
171
+
172
+ ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
173
+ ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
174
+ gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
175
+ gen_file = args.gen_file or config.get("gen_file", "")
176
+
177
+ output_dir = args.output_dir or config.get("output_dir", "tests")
178
+ output_file = args.output_file or config.get(
179
+ "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
180
+ )
181
+
182
+ save_chunk = args.save_chunk
183
+ remove_silence = args.remove_silence
184
+ load_vocoder_from_local = args.load_vocoder_from_local
185
+
186
+ vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
187
+ target_rms = args.target_rms or config.get("target_rms", target_rms)
188
+ cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
189
+ nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
190
+ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
191
+ sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
192
+ speed = args.speed or config.get("speed", speed)
193
+ fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
194
+
195
 
196
  # patches for pip pkg user
197
  if "infer/examples/" in ref_audio:
 
204
  if "infer/examples/" in voice_ref_audio:
205
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
206
 
207
+
208
+ # ignore gen_text if gen_file provided
209
+
210
  if gen_file:
211
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
212
+
213
+
214
+ # output path
 
 
 
 
 
215
 
216
  wave_path = Path(output_dir) / output_file
217
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
218
+ if save_chunk:
219
+ output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
220
+ if not os.path.exists(output_chunk_dir):
221
+ os.makedirs(output_chunk_dir)
222
+
223
+
224
+ # load vocoder
225
 
 
 
226
  if vocoder_name == "vocos":
227
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
228
  elif vocoder_name == "bigvgan":
229
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
230
 
231
+ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
232
+
233
 
234
+ # load TTS model
235
 
 
236
  if model == "F5-TTS":
237
  model_cls = DiT
238
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
239
+ if not ckpt_file: # path not specified, download from repo
240
  if vocoder_name == "vocos":
241
  repo_name = "F5-TTS"
242
  exp_name = "F5TTS_Base"
 
253
  assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
254
  model_cls = UNetT
255
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
256
+ if not ckpt_file: # path not specified, download from repo
257
  repo_name = "E2-TTS"
258
  exp_name = "E2TTS_Base"
259
  ckpt_step = 1200000
260
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
261
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
262
 
 
263
  print(f"Using {model}...")
264
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
265
+
266
 
267
+ # inference process
268
 
269
+
270
+ def main():
271
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
272
  if "voices" not in config:
273
  voices = {"main": main_voice}
 
275
  voices = config["voices"]
276
  voices["main"] = main_voice
277
  for voice in voices:
278
+ print("Voice:", voice)
279
+ print("ref_audio ", voices[voice]["ref_audio"])
280
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
281
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
282
  )
283
+ print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
 
 
284
 
285
  generated_audio_segments = []
286
  reg1 = r"(?=\[\w+\])"
287
+ chunks = re.split(reg1, gen_text)
288
  reg2 = r"\[(\w+)\]"
289
  for text in chunks:
290
  if not text.strip():
 
299
  print(f"Voice {voice} not found, using main.")
300
  voice = "main"
301
  text = re.sub(reg2, "", text)
302
+ ref_audio_ = voices[voice]["ref_audio"]
303
+ ref_text_ = voices[voice]["ref_text"]
304
+ gen_text_ = text.strip()
305
  print(f"Voice: {voice}")
306
+ audio_segment, final_sample_rate, spectrogram = infer_process(
307
+ ref_audio_,
308
+ ref_text_,
309
+ gen_text_,
310
+ ema_model,
311
  vocoder,
312
+ mel_spec_type=vocoder_name,
313
+ target_rms=target_rms,
314
+ cross_fade_duration=cross_fade_duration,
315
  nfe_step=nfe_step,
316
+ cfg_strength=cfg_strength,
317
+ sway_sampling_coef=sway_sampling_coef,
318
+ speed=speed,
319
+ fix_duration=fix_duration,
320
  )
321
+ generated_audio_segments.append(audio_segment)
322
+
323
+ if save_chunk:
324
+ if len(gen_text_) > 200:
325
+ gen_text_ = gen_text_[:200] + " ... "
326
+ sf.write(
327
+ os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
328
+ audio_segment,
329
+ final_sample_rate,
330
+ )
331
 
332
  if generated_audio_segments:
333
  final_wave = np.concatenate(generated_audio_segments)
 
341
  if remove_silence:
342
  remove_silence_for_generated_wav(f.name)
343
  print(f.name)
 
 
 
 
 
344
 
345
 
346
  if __name__ == "__main__":