SWivid committed on
Commit ba4b04b · 1 Parent(s): 254e5e6

finish eval dependencies; update infer_gradio with chat feature
README.md CHANGED
@@ -81,6 +81,9 @@ python scripts/prepare_emilia.py
 
 # Prepare the Wenetspeech4TTS dataset
 python scripts/prepare_wenetspeech4tts.py
+
+# https://github.com/SWivid/F5-TTS/discussions/57#discussioncomment-10959029
+python scripts/prepare_csv_wavs.py
 ```
 
 ## Training & Finetuning
@@ -175,6 +178,7 @@ python inference-cli.py \
 --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
 
 # Multi voice
+# https://github.com/SWivid/F5-TTS/pull/146#issue-2595207852
 python inference-cli.py -c samples/story.toml
 ```
 
@@ -211,54 +215,7 @@ To test speech editing capabilities, use the following command.
 python f5_tts/speech_edit.py
 ```
 
-## Evaluation
-
-### Prepare Test Datasets
-
-1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
-2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
-3. Unzip the downloaded datasets and place them in the data/ directory.
-4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
-5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
-
-### Batch Inference for Test Set
-
-To run batch inference for evaluations, execute the following commands:
-
-```bash
-# switch to the main directory
-cd f5_tts
-
-# batch inference for evaluations
-accelerate config # if not set before
-bash scripts/eval_infer_batch.sh
-```
-
-### Download Evaluation Model Checkpoints
-
-1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
-2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
-3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
-
-### Objective Evaluation
-
-Install packages for evaluation:
-
-```bash
-pip install -e .[eval]
-```
-
-Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
-```bash
-# switch to the main directory
-cd f5_tts
-
-# Evaluation for Seed-TTS test set
-python scripts/eval_seedtts_testset.py
-
-# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
-python scripts/eval_librispeech_test_clean.py
-```
+## [Evaluation](src/f5_tts/eval/README.md)
 
 ## Acknowledgements
 
pyproject.toml CHANGED
@@ -46,6 +46,7 @@ eval = [
     "faster_whisper==0.10.1",
     "funasr",
     "jiwer",
+    "modelscope",
     "zhconv",
     "zhon",
 ]
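For context on what the `eval` extras provide: `funasr` (with `modelscope` for model download) handles Chinese ASR via Paraformer-zh, `faster_whisper` handles English ASR, and `jiwer` scores the word error rate between ground-truth text and the ASR transcript of the generated audio. A minimal, illustrative sketch of the WER step only; the repo's actual pipeline lives in `src/f5_tts/eval/utils_eval.py` and also runs ASR and text normalization first:

```python
# Illustrative sketch only: the WER scoring that jiwer provides for the eval extras.
# The repo's real evaluation also runs ASR (funasr / faster_whisper) and text
# normalization (zhconv / zhon) before this step; both are omitted here.
from jiwer import wer

reference = "some say the world will end in fire"    # ground-truth transcript
hypothesis = "some say the world will end in fires"  # ASR output of a generated wav

print(f"WER: {wer(reference, hypothesis):.3f}")
```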
src/f5_tts/eval/README.md ADDED
@@ -0,0 +1,45 @@
+
+## Evaluation
+
+Install packages for evaluation:
+
+```bash
+pip install -e .[eval]
+```
+
+### Prepare Test Datasets
+
+1. *Seed-TTS testset*: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
+2. *LibriSpeech test-clean*: Download from [OpenSLR](http://www.openslr.org/12/).
+3. Unzip the downloaded datasets and place them in the `data/` directory.
+4. Update the path for the *LibriSpeech test-clean* data in `src/f5_tts/eval/eval_infer_batch.py`.
+5. Our filtered LibriSpeech-PC 4-10s subset is included at `data/librispeech_pc_test_clean_cross_sentence.lst`.
+
+### Batch Inference for Test Set
+
+To run batch inference for evaluations, execute the following commands:
+
+```bash
+# batch inference for evaluations
+accelerate config # if not set before
+bash src/f5_tts/eval/eval_infer_batch.sh
+```
+
+### Download Evaluation Model Checkpoints
+
+1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
+2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
+3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
+
+Then update the paths to these evaluation model checkpoints in the scripts below.
+
+### Objective Evaluation
+
+Update the path to your batch-inference results, and carry out WER / SIM evaluations:
+```bash
+# Evaluation for Seed-TTS test set
+python src/f5_tts/eval/eval_seedtts_testset.py
+
+# Evaluation for LibriSpeech-PC test-clean (cross-sentence)
+python src/f5_tts/eval/eval_librispeech_test_clean.py
```
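As a quick sanity check before the batch inference step above, it can help to confirm that the metadata lists and the unzipped testsets are where these scripts expect them. The snippet below is illustrative only and not part of the repo; the paths are the ones referenced in this commit, resolved from the repository root:

```python
# Illustrative sanity check (not part of the repo): confirm the eval data layout
# referenced by this commit, relative to the repository root.
from pathlib import Path

repo_root = Path(".")  # assumed: current directory is the F5-TTS checkout
expected = [
    "data/librispeech_pc_test_clean_cross_sentence.lst",  # shipped with the repo
    "data/seedtts_testset/zh/meta.lst",                    # from seed-tts-eval
    "data/seedtts_testset/en/meta.lst",                    # from seed-tts-eval
]
for rel in expected:
    path = repo_root / rel
    print(f"{'ok     ' if path.exists() else 'missing'}  {path}")
```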
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -14,9 +14,9 @@ from accelerate import Accelerator
 from vocos import Vocos
 
 from f5_tts.model import CFM, UNetT, DiT
-from f5_tts.model.utils import (
-    load_checkpoint,
-    get_tokenizer,
+from f5_tts.model.utils import get_tokenizer
+from f5_tts.infer.utils_infer import load_checkpoint
+from f5_tts.eval.utils_eval import (
     get_seedtts_testset_metainfo,
     get_librispeech_test_clean_metainfo,
     get_inference_prompt,
@@ -34,6 +34,7 @@ hop_length = 256
 target_rms = 0.1
 
 tokenizer = "pinyin"
+rel_path = str(files("f5_tts").joinpath("../../"))
 
 
 def main():
@@ -58,7 +59,7 @@ def main():
     dataset_name = args.dataset
     exp_name = args.expname
     ckpt_step = args.ckptstep
-    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
+    ckpt_path = rel_path + f"/ckpts/{exp_name}/model_{ckpt_step}.pt"
 
     nfe_step = args.nfestep
     ode_method = args.odemethod
@@ -80,23 +81,22 @@
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-    datapath = files("f5_tts").joinpath("data")
-
     if testset == "ls_pc_test_clean":
-        metalst = os.path.join(datapath, "librispeech_pc_test_clean_cross_sentence.lst")
+        metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
         librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
         metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
 
     elif testset == "seedtts_test_zh":
-        metalst = os.path.join(datapath, "seedtts_testset/zh/meta.lst")
+        metalst = rel_path + "/data/seedtts_testset/zh/meta.lst"
        metainfo = get_seedtts_testset_metainfo(metalst)
 
     elif testset == "seedtts_test_en":
-        metalst = os.path.join(datapath, "seedtts_testset/en/meta.lst")
+        metalst = rel_path + "/data/seedtts_testset/en/meta.lst"
         metainfo = get_seedtts_testset_metainfo(metalst)
 
     # path to save genereted wavs
     output_dir = (
+        f"{rel_path}/"
         f"results/{exp_name}_{ckpt_step}/{testset}/"
        f"seed{seed}_{ode_method}_nfe{nfe_step}"
         f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
src/f5_tts/eval/eval_infer_batch.sh CHANGED
@@ -1,13 +1,13 @@
 #!/bin/bash
 
 # e.g. F5-TTS, 16 NFE
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
 
 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch src/f5_tts/eval/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
 
 # etc.
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -6,18 +6,22 @@ import os
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
+from importlib.resources import files
+
 import numpy as np
 
-from f5_tts.model.utils import (
+from f5_tts.eval.utils_eval import (
     get_librispeech_test,
     run_asr_wer,
     run_sim,
 )
 
+rel_path = str(files("f5_tts").joinpath("../../"))
+
 
 eval_task = "wer"  # sim | wer
 lang = "en"
-metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
+metalst = rel_path + "/data/librispeech_pc_test_clean_cross_sentence.lst"
 librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
 gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
 
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -6,19 +6,23 @@ import os
 sys.path.append(os.getcwd())
 
 import multiprocessing as mp
+from importlib.resources import files
+
 import numpy as np
 
-from f5_tts.model.utils import (
+from f5_tts.eval.utils_eval import (
     get_seed_tts_test,
     run_asr_wer,
     run_sim,
 )
 
+rel_path = str(files("f5_tts").joinpath("../../"))
+
 
 eval_task = "wer"  # sim | wer
 lang = "zh"  # zh | en
-metalst = f"data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
-# gen_wav_dir = f"data/seedtts_testset/{lang}/wavs"  # ground truth wavs
+metalst = rel_path + f"/data/seedtts_testset/{lang}/meta.lst"  # seed-tts testset
+# gen_wav_dir = rel_path + f"/data/seedtts_testset/{lang}/wavs"  # ground truth wavs
 gen_wav_dir = "PATH_TO_GENERATED"  # generated wavs
 
 
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -52,13 +52,11 @@ E2TTS_ema_model = load_model(
     UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
 )
 
-# Initialize Qwen model and tokenizer
-model_name = "Qwen/Qwen2.5-3B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+chat_model_state = None
+chat_tokenizer_state = None
 
 
-def generate_response(messages):
+def generate_response(messages, model, tokenizer):
     """Generate response using Qwen"""
     text = tokenizer.apply_chat_template(
         messages,
@@ -525,137 +523,157 @@ with gr.Blocks() as app_chat:
 # Voice Chat
 Have a conversation with an AI using your reference voice!
 1. Upload a reference audio clip and optionally its transcript.
-2. Record your message through your microphone.
-3. The AI will respond using the reference voice.
+2. Load the chat model.
+3. Record your message through your microphone.
+4. The AI will respond using the reference voice.
 """
     )
 
-    with gr.Row():
-        with gr.Column():
-            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
-
-        with gr.Column():
-            with gr.Accordion("Advanced Settings", open=False):
-                model_choice_chat = gr.Radio(
-                    choices=["F5-TTS", "E2-TTS"],
-                    label="TTS Model",
-                    value="F5-TTS",
-                )
-                remove_silence_chat = gr.Checkbox(
-                    label="Remove Silences",
-                    value=True,
-                )
-                ref_text_chat = gr.Textbox(
-                    label="Reference Text",
-                    info="Optional: Leave blank to auto-transcribe",
-                    lines=2,
-                )
-                system_prompt_chat = gr.Textbox(
-                    label="System Prompt",
-                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
-                    lines=2,
-                )
-
-    chatbot_interface = gr.Chatbot(label="Conversation")
-
-    with gr.Row():
-        with gr.Column():
-            audio_output_chat = gr.Audio(autoplay=True)
-        with gr.Column():
-            audio_input_chat = gr.Microphone(
-                label="Or speak your message",
-                type="filepath",
-            )
-
-    clear_btn_chat = gr.Button("Clear Conversation")
-
-    conversation_state = gr.State(
-        value=[
-            {
-                "role": "system",
-                "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
-            }
-        ]
-    )
-
-    def process_audio_input(audio_path, history, conv_state):
-        """Handle audio input from user"""
-        if not audio_path:
-            return history, conv_state, ""
-
-        text = ""
-        text = preprocess_ref_audio_text(audio_path, text)[1]
-
-        if not text.strip():
-            return history, conv_state, ""
-
-        conv_state.append({"role": "user", "content": text})
-        history.append((text, None))
-
-        response = generate_response(conv_state)
-
-        conv_state.append({"role": "assistant", "content": response})
-        history[-1] = (text, response)
-
-        return history, conv_state, ""
-
-    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
-        """Generate TTS audio for AI response"""
-        if not history or not ref_audio:
-            return None
-
-        last_user_message, last_ai_response = history[-1]
-        if not last_ai_response:
-            return None
-
-        audio_result, _ = infer(
-            ref_audio,
-            ref_text,
-            last_ai_response,
-            model,
-            remove_silence,
-            cross_fade_duration=0.15,
-            speed=1.0,
-        )
-        return audio_result
-
-    def clear_conversation():
-        """Reset the conversation"""
-        return [], [
-            {
-                "role": "system",
-                "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
-            }
-        ]
-
-    def update_system_prompt(new_prompt):
-        """Update the system prompt and reset the conversation"""
-        new_conv_state = [{"role": "system", "content": new_prompt}]
-        return [], new_conv_state
-
-    # Handle audio input
-    audio_input_chat.stop_recording(
-        process_audio_input,
-        inputs=[audio_input_chat, chatbot_interface, conversation_state],
-        outputs=[chatbot_interface, conversation_state],
-    ).then(
-        generate_audio_response,
-        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
-        outputs=audio_output_chat,
-    )
-
-    # Handle clear button
-    clear_btn_chat.click(
-        clear_conversation,
-        outputs=[chatbot_interface, conversation_state],
-    )
-
-    # Handle system prompt change and reset conversation
-    system_prompt_chat.change(
-        update_system_prompt,
-        inputs=system_prompt_chat,
-        outputs=[chatbot_interface, conversation_state],
-    )
+    load_chat_model_btn = gr.Button("Load Chat Model", variant="primary")
+
+    chat_interface_container = gr.Column(visible=False)
+
+    def load_chat_model():
+        global chat_model_state, chat_tokenizer_state
+        if chat_model_state is None:
+            show_info = gr.Info
+            show_info("Loading chat model...")
+            model_name = "Qwen/Qwen2.5-3B-Instruct"
+            chat_model_state = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+            chat_tokenizer_state = AutoTokenizer.from_pretrained(model_name)
+            show_info("Chat model loaded.")
+
+        return gr.update(visible=False), gr.update(visible=True)
+
+    load_chat_model_btn.click(load_chat_model, outputs=[load_chat_model_btn, chat_interface_container])
+
+    with chat_interface_container:
+        with gr.Row():
+            with gr.Column():
+                ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
+            with gr.Column():
+                with gr.Accordion("Advanced Settings", open=False):
+                    model_choice_chat = gr.Radio(
+                        choices=["F5-TTS", "E2-TTS"],
+                        label="TTS Model",
+                        value="F5-TTS",
+                    )
+                    remove_silence_chat = gr.Checkbox(
+                        label="Remove Silences",
+                        value=True,
+                    )
+                    ref_text_chat = gr.Textbox(
+                        label="Reference Text",
+                        info="Optional: Leave blank to auto-transcribe",
+                        lines=2,
+                    )
+                    system_prompt_chat = gr.Textbox(
+                        label="System Prompt",
+                        value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                        lines=2,
+                    )
+
+        chatbot_interface = gr.Chatbot(label="Conversation")
+
+        with gr.Row():
+            with gr.Column():
+                audio_output_chat = gr.Audio(autoplay=True)
+            with gr.Column():
+                audio_input_chat = gr.Microphone(
+                    label="Speak your message",
+                    type="filepath",
+                )
+
+        clear_btn_chat = gr.Button("Clear Conversation")
+
+        conversation_state = gr.State(
+            value=[
+                {
+                    "role": "system",
+                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                }
+            ]
+        )
+
+        # Modify process_audio_input to use model and tokenizer from state
+        def process_audio_input(audio_path, history, conv_state):
+            """Handle audio input from user"""
+            if not audio_path:
+                return history, conv_state, ""
+
+            text = ""
+            text = preprocess_ref_audio_text(audio_path, text)[1]
+
+            if not text.strip():
+                return history, conv_state, ""
+
+            conv_state.append({"role": "user", "content": text})
+            history.append((text, None))
+
+            response = generate_response(conv_state, chat_model_state, chat_tokenizer_state)
+
+            conv_state.append({"role": "assistant", "content": response})
+            history[-1] = (text, response)
+
+            return history, conv_state, ""
+
+        def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
+            """Generate TTS audio for AI response"""
+            if not history or not ref_audio:
+                return None
+
+            last_user_message, last_ai_response = history[-1]
+            if not last_ai_response:
+                return None
+
+            audio_result, _ = infer(
+                ref_audio,
+                ref_text,
+                last_ai_response,
+                model,
+                remove_silence,
+                cross_fade_duration=0.15,
+                speed=1.0,
+            )
+            return audio_result
+
+        def clear_conversation():
+            """Reset the conversation"""
+            return [], [
+                {
+                    "role": "system",
+                    "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                }
+            ]
+
+        def update_system_prompt(new_prompt):
+            """Update the system prompt and reset the conversation"""
+            new_conv_state = [{"role": "system", "content": new_prompt}]
+            return [], new_conv_state
+
+        # Handle audio input
+        audio_input_chat.stop_recording(
+            process_audio_input,
+            inputs=[audio_input_chat, chatbot_interface, conversation_state],
+            outputs=[chatbot_interface, conversation_state],
+        ).then(
+            generate_audio_response,
+            inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
+            outputs=audio_output_chat,
+        )
+
+        # Handle clear button
+        clear_btn_chat.click(
+            clear_conversation,
+            outputs=[chatbot_interface, conversation_state],
+        )
+
+        # Handle system prompt change and reset conversation
+        system_prompt_chat.change(
+            update_system_prompt,
+            inputs=system_prompt_chat,
+            outputs=[chatbot_interface, conversation_state],
+        )
 
 
 with gr.Blocks() as app:
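The chat tab now loads the Qwen LLM only when the user clicks a button, rather than at import time, so the Gradio app starts quickly for users who never open the voice-chat tab. A stripped-down sketch of that lazy-load pattern is below; `expensive_model_state` and `load_expensive_model` are hypothetical stand-ins, not names from the repo:

```python
# Stripped-down sketch of the lazy-load pattern used by the chat tab:
# a button loads a heavy model on demand, then hides itself and reveals the dependent UI.
import gradio as gr

expensive_model_state = None  # stands in for chat_model_state / chat_tokenizer_state


def load_expensive_model():
    global expensive_model_state
    if expensive_model_state is None:
        # placeholder for e.g. AutoModelForCausalLM.from_pretrained(...)
        expensive_model_state = object()
    # hide the button, reveal the container that needs the model
    return gr.update(visible=False), gr.update(visible=True)


with gr.Blocks() as demo:
    load_btn = gr.Button("Load model", variant="primary")
    container = gr.Column(visible=False)
    with container:
        gr.Markdown("Model-dependent UI goes here.")
    load_btn.click(load_expensive_model, outputs=[load_btn, container])

# demo.launch()  # uncomment to run locally
```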
src/f5_tts/infer/utils_infer.py CHANGED
@@ -1,6 +1,7 @@
 # A unified script for inference process
 # Make adjustments inside functions, and consider both gradio and cli scripts if need to change func output format
 
+import hashlib
 import re
 import tempfile
 
@@ -23,6 +24,7 @@ from f5_tts.model.utils import (
     convert_char_to_pinyin,
 )
 
+_ref_audio_cache = {}
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
@@ -194,23 +196,36 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=
         aseg.export(f.name, format="wav")
         ref_audio = f.name
 
-    if not ref_text.strip():
-        global asr_pipe
-        if asr_pipe is None:
-            initialize_asr_pipeline(device=device)
-        show_info("No reference text provided, transcribing reference audio...")
-        ref_text = asr_pipe(
-            ref_audio,
-            chunk_length_s=30,
-            batch_size=128,
-            generate_kwargs={"task": "transcribe"},
-            return_timestamps=False,
-        )["text"].strip()
-        show_info("Finished transcription")
+    # Compute a hash of the reference audio file
+    with open(ref_audio, "rb") as audio_file:
+        audio_data = audio_file.read()
+        audio_hash = hashlib.md5(audio_data).hexdigest()
+
+    global _ref_audio_cache
+    if audio_hash in _ref_audio_cache:
+        # Use cached reference text
+        show_info("Using cached reference text...")
+        ref_text = _ref_audio_cache[audio_hash]
     else:
-        show_info("Using custom reference text...")
+        if not ref_text.strip():
+            global asr_pipe
+            if asr_pipe is None:
+                initialize_asr_pipeline(device=device)
+            show_info("No reference text provided, transcribing reference audio...")
+            ref_text = asr_pipe(
+                ref_audio,
+                chunk_length_s=30,
+                batch_size=128,
+                generate_kwargs={"task": "transcribe"},
+                return_timestamps=False,
+            )["text"].strip()
+            show_info("Finished transcription")
+        else:
+            show_info("Using custom reference text...")
+        # Cache the transcribed text
+        _ref_audio_cache[audio_hash] = ref_text
 
-    # Add the functionality to ensure it ends with ". "
+    # Ensure ref_text ends with a proper sentence-ending punctuation
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
         if ref_text.endswith("."):
             ref_text += " "
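The new `_ref_audio_cache` keys transcriptions by an MD5 digest of the reference audio bytes, so repeated calls with the same clip skip the ASR pass. The same idea, reduced to a standalone sketch with a hypothetical `transcribe` stand-in for the real `asr_pipe`:

```python
# Standalone sketch of the caching idea in preprocess_ref_audio_text:
# key an expensive result by the MD5 digest of the input file's bytes.
# `transcribe` is a hypothetical stand-in for the actual ASR pipeline.
import hashlib

_cache: dict[str, str] = {}


def transcribe(path: str) -> str:
    return f"transcript of {path}"  # placeholder for the real ASR call


def get_ref_text(audio_path: str) -> str:
    with open(audio_path, "rb") as f:
        digest = hashlib.md5(f.read()).hexdigest()
    if digest not in _cache:
        _cache[digest] = transcribe(audio_path)  # only run ASR on a cache miss
    return _cache[digest]
```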
src/f5_tts/model/utils.py CHANGED
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 import os
 import random
-from importlib.resources import files
 from collections import defaultdict
+from importlib.resources import files
 
 import torch
 from torch.nn.utils.rnn import pad_sequence
@@ -109,7 +109,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
                 - if use "byte", set to 256 (unicode byte range)
     """
     if tokenizer in ["pinyin", "char"]:
-        tokenizer_path = os.path.join(files("f5_tts").joinpath("data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
         with open(tokenizer_path, "r", encoding="utf-8") as f:
             vocab_char_map = {}
             for i, char in enumerate(f):
@@ -120,6 +120,7 @@ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
     elif tokenizer == "byte":
         vocab_char_map = None
         vocab_size = 256
+
     elif tokenizer == "custom":
         with open(dataset_name, "r", encoding="utf-8") as f:
             vocab_char_map = {}