SWivid committed
Commit 254e5e6 · 1 Parent(s): b4abb3c

update finetune-cli -gradio

README.md CHANGED
@@ -183,6 +183,7 @@ Currently supported features:
 - Chunk inference
 - Podcast Generation
 - Multiple Speech-Type Generation
+- Voice Chat powered by Qwen2.5-3B-Instruct
 
 You can launch a Gradio app (web interface) for GUI-based inference (it loads the ckpt from Hugging Face by default; you may also use a local file in `gradio_app.py`). It currently loads the ASR model, F5-TTS and E2 TTS all at once, so it uses more GPU memory than `inference-cli`.
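
As a quick orientation for the feature list above, here is a hedged sketch of starting the web UI from Python. It assumes the `src/f5_tts/infer/infer_gradio.py` module changed later in this commit exposes a module-level `gr.Blocks` named `app`; the project may also ship its own CLI entry point.

```python
# Hedged sketch, not part of this commit: launch the Gradio UI programmatically.
# Importing the module loads the ASR, F5-TTS, E2 TTS and Qwen checkpoints up front,
# which is why it needs more GPU memory than `inference-cli`.
from f5_tts.infer.infer_gradio import app

app.launch()  # serves the TTS / Podcast / Multi-Style / Voice-Chat tabs
```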
 
pyproject.toml CHANGED
@@ -35,6 +35,7 @@ dependencies = [
     "torchdiffeq",
     "tqdm>=4.65.0",
     "transformers",
+    "transformers_stream_generator",
     "vocos",
     "wandb",
     "x_transformers>=1.31.14",
src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -4,7 +4,6 @@ import os
 sys.path.append(os.getcwd())
 
 import time
-import random
 from tqdm import tqdm
 import argparse
 from importlib.resources import files
@@ -97,8 +96,6 @@ def main():
     metainfo = get_seedtts_testset_metainfo(metalst)
 
     # path to save generated wavs
-    if seed is None:
-        seed = random.randint(-10000, 10000)
     output_dir = (
         f"results/{exp_name}_{ckpt_step}/{testset}/"
         f"seed{seed}_{ode_method}_nfe{nfe_step}"
src/f5_tts/eval/eval_infer_batch.sh CHANGED
@@ -1,13 +1,13 @@
 #!/bin/bash
 
 # e.g. F5-TTS, 16 NFE
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
-accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
 
 # e.g. Vanilla E2 TTS, 32 NFE
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
-accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+accelerate launch scripts/eval_infer_batch.py -s 0 -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
 
 # etc.
src/f5_tts/infer/infer_gradio.py CHANGED
@@ -11,6 +11,7 @@ import soundfile as sf
 import torchaudio
 from cached_path import cached_path
 from pydub import AudioSegment
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 try:
     import spaces
@@ -51,6 +52,33 @@ E2TTS_ema_model = load_model(
     UNetT, E2TTS_model_cfg, str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
 )
 
+# Initialize Qwen model and tokenizer
+model_name = "Qwen/Qwen2.5-3B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+
+def generate_response(messages):
+    """Generate response using Qwen"""
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    generated_ids = model.generate(
+        **model_inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.95,
+    )
+
+    generated_ids = [
+        output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+    ]
+    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
 
 @gpu_decorator
 def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1):
@@ -490,6 +518,146 @@ with gr.Blocks() as app_emotional:
         outputs=generate_emotional_btn,
     )
 
+
+with gr.Blocks() as app_chat:
+    gr.Markdown(
+        """
+# Voice Chat
+Have a conversation with an AI using your reference voice!
+1. Upload a reference audio clip and optionally its transcript.
+2. Record your message through your microphone.
+3. The AI will respond using the reference voice.
+"""
+    )
+
+    with gr.Row():
+        with gr.Column():
+            ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
+
+        with gr.Column():
+            with gr.Accordion("Advanced Settings", open=False):
+                model_choice_chat = gr.Radio(
+                    choices=["F5-TTS", "E2-TTS"],
+                    label="TTS Model",
+                    value="F5-TTS",
+                )
+                remove_silence_chat = gr.Checkbox(
+                    label="Remove Silences",
+                    value=True,
+                )
+                ref_text_chat = gr.Textbox(
+                    label="Reference Text",
+                    info="Optional: Leave blank to auto-transcribe",
+                    lines=2,
+                )
+                system_prompt_chat = gr.Textbox(
+                    label="System Prompt",
+                    value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                    lines=2,
+                )
+
+    chatbot_interface = gr.Chatbot(label="Conversation")
+
+    with gr.Row():
+        with gr.Column():
+            audio_output_chat = gr.Audio(autoplay=True)
+        with gr.Column():
+            audio_input_chat = gr.Microphone(
+                label="Or speak your message",
+                type="filepath",
+            )
+
+    clear_btn_chat = gr.Button("Clear Conversation")
+
+    conversation_state = gr.State(
+        value=[
+            {
+                "role": "system",
+                "content": "You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+            }
+        ]
+    )
+
+    def process_audio_input(audio_path, history, conv_state):
+        """Handle audio input from user"""
+        if not audio_path:
+            return history, conv_state, ""
+
+        text = ""
+        text = preprocess_ref_audio_text(audio_path, text)[1]
+
+        if not text.strip():
+            return history, conv_state, ""
+
+        conv_state.append({"role": "user", "content": text})
+        history.append((text, None))
+
+        response = generate_response(conv_state)
+
+        conv_state.append({"role": "assistant", "content": response})
+        history[-1] = (text, response)
+
+        return history, conv_state, ""
+
+    def generate_audio_response(history, ref_audio, ref_text, model, remove_silence):
+        """Generate TTS audio for AI response"""
+        if not history or not ref_audio:
+            return None
+
+        last_user_message, last_ai_response = history[-1]
+        if not last_ai_response:
+            return None
+
+        audio_result, _ = infer(
+            ref_audio,
+            ref_text,
+            last_ai_response,
+            model,
+            remove_silence,
+            cross_fade_duration=0.15,
+            speed=1.0,
+        )
+        return audio_result
+
+    def clear_conversation():
+        """Reset the conversation"""
+        return [], [
+            {
+                "role": "system",
+                "content": "You are a friendly person, and may impersonate whoever they address you as. Stay in character. Keep your responses concise since they will be spoken out loud.",
+            }
+        ]
+
+    def update_system_prompt(new_prompt):
+        """Update the system prompt and reset the conversation"""
+        new_conv_state = [{"role": "system", "content": new_prompt}]
+        return [], new_conv_state
+
+    # Handle audio input
+    audio_input_chat.stop_recording(
+        process_audio_input,
+        inputs=[audio_input_chat, chatbot_interface, conversation_state],
+        outputs=[chatbot_interface, conversation_state],
+    ).then(
+        generate_audio_response,
+        inputs=[chatbot_interface, ref_audio_chat, ref_text_chat, model_choice_chat, remove_silence_chat],
+        outputs=audio_output_chat,
+    )
+
+    # Handle clear button
+    clear_btn_chat.click(
+        clear_conversation,
+        outputs=[chatbot_interface, conversation_state],
+    )
+
+    # Handle system prompt change and reset conversation
+    system_prompt_chat.change(
+        update_system_prompt,
+        inputs=system_prompt_chat,
+        outputs=[chatbot_interface, conversation_state],
+    )
+
+
 with gr.Blocks() as app:
     gr.Markdown(
         """
@@ -507,7 +675,10 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
     **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
     """
     )
-    gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
+    gr.TabbedInterface(
+        [app_tts, app_podcast, app_emotional, app_chat, app_credits],
+        ["TTS", "Podcast", "Multi-Style", "Voice-Chat", "Credits"],
+    )
 
 
 @click.command()
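
To make the new Voice-Chat wiring above easier to follow, here is an illustrative sketch (not part of the commit) of the same chain outside Gradio: transcribe the recorded message, ask Qwen for a reply via `generate_response`, then speak the reply with `infer` using the reference voice. File paths and the system prompt are placeholders.

```python
# Illustrative only; assumes the helpers defined in infer_gradio.py above are in scope.
conv_state = [{"role": "system", "content": "Keep your responses concise."}]

# 1) transcribe the microphone recording (an empty ref_text triggers auto-transcription)
user_text = preprocess_ref_audio_text("user_message.wav", "")[1]
conv_state.append({"role": "user", "content": user_text})

# 2) get a text reply from Qwen2.5-3B-Instruct
reply = generate_response(conv_state)
conv_state.append({"role": "assistant", "content": reply})

# 3) synthesize the reply in the reference speaker's voice
audio, _ = infer("ref_voice.wav", "", reply, "F5-TTS", remove_silence=True)
```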
src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -60,7 +60,7 @@ def read_audio_text_pairs(csv_file_path):
     audio_text_pairs = []
 
     parent = Path(csv_file_path).parent
-    with open(csv_file_path, mode="r", newline="", encoding="utf-8") as csvfile:
+    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
         reader = csv.reader(csvfile, delimiter="|")
         next(reader)  # Skip the header row
         for row in reader:
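
A note on the encoding switch above: `utf-8-sig` decodes like UTF-8 but strips a leading byte-order mark, which spreadsheet exports often prepend to `metadata.csv`; with plain `utf-8` the BOM stays glued to the first field. A minimal illustration (the header names here are made up):

```python
# "utf-8-sig" silently drops a leading UTF-8 BOM; plain "utf-8" keeps it.
raw = "\ufeffaudio_file|text\nwavs/0001.wav|hello world\n".encode("utf-8")

print(raw.decode("utf-8").splitlines()[0])      # '\ufeffaudio_file|text'  (BOM pollutes the header)
print(raw.decode("utf-8-sig").splitlines()[0])  # 'audio_file|text'
```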
src/f5_tts/train/finetune_cli.py CHANGED
@@ -15,26 +15,35 @@ hop_length = 256
 
 # -------------------------- Argument Parsing --------------------------- #
 def parse_args():
+    # batch_size_per_gpu = 1000  setting for an 8GB GPU
+    # batch_size_per_gpu = 1600  setting for a 12GB GPU
+    # batch_size_per_gpu = 2000  setting for a 16GB GPU
+    # batch_size_per_gpu = 3200  setting for a 24GB GPU
+
+    # num_warmup_updates = 500 for 10000 samples
+
+    # change save_per_updates and last_per_steps as needed
+
     parser = argparse.ArgumentParser(description="Train CFM Model")
 
     parser.add_argument(
         "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
     )
     parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
-    parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for training")
+    parser.add_argument("--batch_size_per_gpu", type=int, default=3200, help="Batch size per GPU")
     parser.add_argument(
         "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
     )
-    parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
+    parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
     parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
     parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
-    parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
-    parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
-    parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
+    parser.add_argument("--num_warmup_updates", type=int, default=500, help="Warmup steps")
+    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
+    parser.add_argument("--last_per_steps", type=int, default=20000, help="Save last checkpoint every X steps")
     parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-
+    parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
     parser.add_argument(
         "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
     )
@@ -60,13 +69,19 @@ def main():
         model_cls = DiT
         model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
     elif args.exp_name == "E2TTS_Base":
         wandb_resume_id = None
         model_cls = UNetT
         model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         if args.finetune:
-            ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            if args.pretrain is None:
+                ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
+            else:
+                ckpt_path = args.pretrain
 
     if args.finetune:
         path_ckpt = os.path.join("ckpts", args.dataset_name)
@@ -118,6 +133,7 @@ def main():
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
+
     trainer.train(
         train_dataset,
         resumable_with_seed=666,  # seed for shuffling dataset
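
For orientation on the new defaults above: with `--batch_size_type frame`, `--batch_size_per_gpu` counts mel frames, so the GPU-memory guidance in the comments translates into roughly this much audio per batch. This is a back-of-the-envelope sketch; the 24 kHz sampling rate is the value used by `calculate_train` in `finetune_gradio.py` below.

```python
# Rough seconds of audio per GPU batch for the suggested frame budgets.
hop_length = 256        # mel hop size, as defined at the top of finetune_cli.py
sampling_rate = 24000   # assumed; matches mel_sampling_rate in finetune_gradio.py

for frames in (1000, 1600, 2000, 3200):  # 8 GB / 12 GB / 16 GB / 24 GB suggestions
    print(f"{frames} frames ≈ {frames * hop_length / sampling_rate:.1f} s of audio")
# 1000 ≈ 10.7 s, 1600 ≈ 17.1 s, 2000 ≈ 21.3 s, 3200 ≈ 34.1 s
```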
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -20,6 +20,7 @@ import torch
 import torchaudio
 from datasets import Dataset as Dataset_
 from datasets.arrow_writer import ArrowWriter
+from safetensors.torch import save_file
 from scipy.io import wavfile
 from transformers import pipeline
 
@@ -247,6 +248,9 @@ def start_training(
     save_per_updates=400,
     last_per_steps=800,
     finetune=True,
+    file_checkpoint_train="",
+    tokenizer_type="pinyin",
+    tokenizer_file="",
 ):
     global training_process, tts_api
 
@@ -256,7 +260,7 @@ def start_training(
     torch.cuda.empty_cache()
     tts_api = None
 
-    path_project = os.path.join(path_data, dataset_name + "_pinyin")
+    path_project = os.path.join(path_data, dataset_name)
 
     if not os.path.isdir(path_project):
         yield (
@@ -278,6 +282,7 @@ def start_training(
     yield "start train", gr.update(interactive=False), gr.update(interactive=False)
 
     # Command to run the training script with the specified arguments
+    dataset_name = dataset_name.replace("_pinyin", "").replace("_char", "")
     cmd = (
         f"accelerate launch finetune-cli.py --exp_name {exp_name} "
        f"--learning_rate {learning_rate} "
@@ -295,6 +300,13 @@ def start_training(
     if finetune:
         cmd += f" --finetune {finetune}"
 
+    if file_checkpoint_train != "":
+        cmd += f" --file_checkpoint_train {file_checkpoint_train}"
+
+    if tokenizer_file != "":
+        cmd += f" --tokenizer_path {tokenizer_file}"
+        cmd += f" --tokenizer {tokenizer_type} "
+
     print(cmd)
 
     try:
@@ -331,10 +343,28 @@ def stop_training():
     return "train stop", gr.update(interactive=True), gr.update(interactive=False)
 
 
-def create_data_project(name):
-    name += "_pinyin"
+def get_list_projects():
+    project_list = []
+    for folder in os.listdir("data"):
+        path_folder = os.path.join("data", folder)
+        if not os.path.isdir(path_folder):
+            continue
+        folder = folder.lower()
+        if folder == "emilia_zh_en_pinyin":
+            continue
+        project_list.append(folder)
+
+    projects_selelect = None if not project_list else project_list[-1]
+
+    return project_list, projects_selelect
+
+
+def create_data_project(name, tokenizer_type):
+    name += "_" + tokenizer_type
     os.makedirs(os.path.join(path_data, name), exist_ok=True)
     os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
+    project_list, projects_selelect = get_list_projects()
+    return gr.update(choices=project_list, value=name)
 
 
 def transcribe(file_audio, language="english"):
@@ -359,14 +389,14 @@ def transcribe(file_audio, language="english"):
 
 
 def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_dataset = os.path.join(path_project, "dataset")
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
 
-    if audio_files is None:
-        return "You need to load an audio file."
+    if not user:
+        if audio_files is None:
+            return "You need to load an audio file."
 
     if os.path.isdir(path_project_wavs):
         shutil.rmtree(path_project_wavs)
@@ -418,7 +448,7 @@ def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
     except:  # noqa: E722
         error_num += 1
 
-    with open(file_metadata, "w", encoding="utf-8") as f:
+    with open(file_metadata, "w", encoding="utf-8-sig") as f:
         f.write(data)
 
     if error_num != []:
@@ -437,7 +467,6 @@ def format_seconds_to_hms(seconds):
 
 
 def create_metadata(name_project, progress=gr.Progress()):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     path_project_wavs = os.path.join(path_project, "wavs")
     file_metadata = os.path.join(path_project, "metadata.csv")
@@ -448,7 +477,7 @@ def create_metadata(name_project, progress=gr.Progress()):
     if not os.path.isfile(file_metadata):
         return "The file was not found in " + file_metadata
 
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     audio_path_list = []
@@ -499,7 +528,7 @@ def create_metadata(name_project, progress=gr.Progress()):
     for line in progress.tqdm(result, total=len(result), desc="prepare data"):
         writer.write(line)
 
-    with open(file_duration, "w", encoding="utf-8") as f:
+    with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
 
     file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
@@ -529,7 +558,6 @@ def calculate_train(
     last_per_steps,
     finetune,
 ):
-    name_project += "_pinyin"
     path_project = os.path.join(path_data, name_project)
     file_duraction = os.path.join(path_project, "duration.json")
 
@@ -548,8 +576,8 @@ def calculate_train(
         data = json.load(file)
 
     duration_list = data["duration"]
-
     samples = len(duration_list)
+    hours = sum(duration_list) / 3600
 
     if torch.cuda.is_available():
         gpu_properties = torch.cuda.get_device_properties(0)
@@ -583,34 +611,67 @@
     save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
     last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
 
+    total_hours = hours
+    mel_hop_length = 256
+    mel_sampling_rate = 24000
+
+    # target
+    wanted_max_updates = 1000000
+
+    # train params
+    gpus = 1
+    frames_per_gpu = batch_size_per_gpu  # 8 * 38400 = 307200
+    grad_accum = 1
+
+    # intermediate
+    mini_batch_frames = frames_per_gpu * grad_accum * gpus
+    mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
+    updates_per_epoch = total_hours / mini_batch_hours
+    # steps_per_epoch = updates_per_epoch * grad_accum
+    epochs = wanted_max_updates / updates_per_epoch
+
     if finetune:
         learning_rate = 1e-5
     else:
         learning_rate = 7.5e-5
 
-    return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
+    return (
+        batch_size_per_gpu,
+        max_samples,
+        num_warmup_updates,
+        save_per_updates,
+        last_per_steps,
+        samples,
+        learning_rate,
+        int(epochs),
+    )
 
 
-def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
+def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, safetensors: bool) -> str:
    try:
         checkpoint = torch.load(checkpoint_path)
         print("Original Checkpoint Keys:", checkpoint.keys())
 
         ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
+        if ema_model_state_dict is None:
+            return "No 'ema_model_state_dict' found in the checkpoint."
 
-        if ema_model_state_dict is not None:
+        if safetensors:
+            new_checkpoint_path = new_checkpoint_path.replace(".pt", ".safetensors")
+            save_file(ema_model_state_dict, new_checkpoint_path)
+        else:
+            new_checkpoint_path = new_checkpoint_path.replace(".safetensors", ".pt")
             new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
             torch.save(new_checkpoint, new_checkpoint_path)
-            return f"New checkpoint saved at: {new_checkpoint_path}"
-        else:
-            return "No 'ema_model_state_dict' found in the checkpoint."
+
+        return f"New checkpoint saved at: {new_checkpoint_path}"
 
     except Exception as e:
         return f"An error occurred: {e}"
 
 
 def vocab_check(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
 
     file_metadata = os.path.join(path_project, "metadata.csv")
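
The `calculate_train` change above now also estimates how many epochs are needed to reach the 1,000,000-update budget from the dataset duration and the frame batch size. A worked example with made-up numbers:

```python
# Worked example of the epoch estimate in calculate_train (illustrative numbers only).
total_hours = 100                      # assumed dataset duration
frames_per_gpu = 3200                  # batch_size_per_gpu, counted in mel frames
mel_hop_length, mel_sampling_rate = 256, 24000

mini_batch_hours = frames_per_gpu * mel_hop_length / mel_sampling_rate / 3600  # ≈ 0.0095 h per update
updates_per_epoch = total_hours / mini_batch_hours                             # ≈ 10,547 updates per epoch
epochs = int(1_000_000 / updates_per_epoch)                                    # ≈ 94 epochs to reach 1M updates
print(epochs)
```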
 
@@ -619,15 +680,15 @@ def vocab_check(project_name):
     if not os.path.isfile(file_vocab):
         return f"the file {file_vocab} not found !"
 
-    with open(file_vocab, "r", encoding="utf-8") as f:
+    with open(file_vocab, "r", encoding="utf-8-sig") as f:
         data = f.read()
-
-    vocab = data.split("\n")
+    vocab = data.split("\n")
+    vocab = set(vocab)
 
     if not os.path.isfile(file_metadata):
         return f"the file {file_metadata} not found !"
 
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     miss_symbols = []
@@ -652,7 +713,7 @@
 
 
 def get_random_sample_prepare(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_arrow = os.path.join(path_project, "raw.arrow")
     if not os.path.isfile(file_arrow):
@@ -665,14 +726,14 @@
 
 
 def get_random_sample_transcribe(project_name):
-    name_project = project_name + "_pinyin"
+    name_project = project_name
     path_project = os.path.join(path_data, name_project)
     file_metadata = os.path.join(path_project, "metadata.csv")
     if not os.path.isfile(file_metadata):
         return "", None
 
     data = ""
-    with open(file_metadata, "r", encoding="utf-8") as f:
+    with open(file_metadata, "r", encoding="utf-8-sig") as f:
         data = f.read()
 
     list_data = []
@@ -703,13 +764,14 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
     global last_checkpoint, last_device, tts_api
 
     if not os.path.isfile(file_checkpoint):
-        return None
+        return None, "checkpoint not found!"
 
     if training_process is not None:
         device_test = "cpu"
     else:
         device_test = None
 
+    device_test = "cpu"
     if last_checkpoint != file_checkpoint or last_device != device_test:
         if last_checkpoint != file_checkpoint:
             last_checkpoint = file_checkpoint
@@ -722,19 +784,67 @@ def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
-        return f.name
+        return f.name, tts_api.device
+
+
+def check_finetune(finetune):
+    return gr.update(interactive=finetune), gr.update(interactive=finetune), gr.update(interactive=finetune)
+
+
+def get_checkpoints_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    project_name = project_name.replace("_pinyin", "").replace("_char", "")
+    path_project_ckpts = os.path.join("ckpts", project_name)
+
+    if os.path.isdir(path_project_ckpts):
+        files_checkpoints = glob(os.path.join(path_project_ckpts, "*.pt"))
+        files_checkpoints = sorted(
+            files_checkpoints,
+            key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
+            if os.path.basename(x) != "model_last.pt"
+            else float("inf"),
+        )
+    else:
+        files_checkpoints = []
+
+    selelect_checkpoint = None if not files_checkpoints else files_checkpoints[0]
+
+    if is_gradio:
+        return gr.update(choices=files_checkpoints, value=selelect_checkpoint)
+
+    return files_checkpoints, selelect_checkpoint
 
 
 with gr.Blocks() as app:
+    gr.Markdown(
+        """
+# E2/F5 TTS AUTOMATIC FINETUNE
+
+This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
+
+* [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
+* [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
+
+The checkpoints support English and Chinese.
+
+for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussions/143)
+"""
+    )
+
     with gr.Row():
+        projects, projects_selelect = get_list_projects()
+        tokenizer_type = gr.Radio(label="Tokenizer Type", choices=["pinyin", "char"], value="pinyin")
         project_name = gr.Textbox(label="project name", value="my_speak")
         bt_create = gr.Button("create new project")
 
-    bt_create.click(fn=create_data_project, inputs=[project_name])
+    cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
+
+    bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
 
     with gr.Tabs():
         with gr.TabItem("transcribe Data"):
-            ch_manual = gr.Checkbox(label="user", value=False)
+            ch_manual = gr.Checkbox(label="audio from path", value=False)
 
             mark_info_transcribe = gr.Markdown(
                 """```plaintext
@@ -756,7 +866,7 @@ with gr.Blocks() as app:
             txt_info_transcribe = gr.Text(label="info", value="")
             bt_transcribe.click(
                 fn=transcribe_all,
-                inputs=[project_name, audio_speaker, txt_lang, ch_manual],
+                inputs=[cm_project, audio_speaker, txt_lang, ch_manual],
                 outputs=[txt_info_transcribe],
             )
             ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
@@ -769,7 +879,7 @@ with gr.Blocks() as app:
 
             random_sample_transcribe.click(
                 fn=get_random_sample_transcribe,
-                inputs=[project_name],
+                inputs=[cm_project],
                 outputs=[random_text_transcribe, random_audio_transcribe],
             )
 
@@ -797,7 +907,7 @@ with gr.Blocks() as app:
 
             bt_prepare = bt_create = gr.Button("prepare")
             txt_info_prepare = gr.Text(label="info", value="")
-            bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
+            bt_prepare.click(fn=create_metadata, inputs=[cm_project], outputs=[txt_info_prepare])
 
             random_sample_prepare = gr.Button("random sample")
 
@@ -806,16 +916,20 @@ with gr.Blocks() as app:
                 random_audio_prepare = gr.Audio(label="Audio", type="filepath")
 
             random_sample_prepare.click(
-                fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
+                fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
             )
 
         with gr.TabItem("train Data"):
            with gr.Row():
                 bt_calculate = bt_create = gr.Button("Auto Settings")
-                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
                 lb_samples = gr.Label(label="samples")
                 batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
 
+            with gr.Row():
+                ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
+                tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
+                file_checkpoint_train = gr.Textbox(label="Checkpoint", value="")
+
             with gr.Row():
                 exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
                 learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
@@ -844,7 +958,7 @@ with gr.Blocks() as app:
             start_button.click(
                 fn=start_training,
                 inputs=[
-                    project_name,
+                    cm_project,
                     exp_name,
                     learning_rate,
                     batch_size_per_gpu,
@@ -857,14 +971,18 @@ with gr.Blocks() as app:
                     save_per_updates,
                     last_per_steps,
                     ch_finetune,
+                    file_checkpoint_train,
+                    tokenizer_type,
+                    tokenizer_file,
                 ],
                 outputs=[txt_info_train, start_button, stop_button],
             )
             stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
+
             bt_calculate.click(
                 fn=calculate_train,
                 inputs=[
-                    project_name,
+                    cm_project,
                     batch_size_type,
                     max_samples,
                     learning_rate,
@@ -881,29 +999,42 @@ with gr.Blocks() as app:
                     last_per_steps,
                     lb_samples,
                     learning_rate,
+                    epochs,
                 ],
             )
 
+            ch_finetune.change(
+                check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
+            )
+
        with gr.TabItem("reduse checkpoint"):
             txt_path_checkpoint = gr.Text(label="path checkpoint :")
             txt_path_checkpoint_small = gr.Text(label="path output :")
+            ch_safetensors = gr.Checkbox(label="safetensors", value="")
             txt_info_reduse = gr.Text(label="info", value="")
             reduse_button = gr.Button("reduse")
             reduse_button.click(
                 fn=extract_and_save_ema_model,
-                inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
+                inputs=[txt_path_checkpoint, txt_path_checkpoint_small, ch_safetensors],
                 outputs=[txt_info_reduse],
             )
 
        with gr.TabItem("vocab check experiment"):
             check_button = gr.Button("check vocab")
             txt_info_check = gr.Text(label="info", value="")
-            check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
+            check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
 
        with gr.TabItem("test model"):
             exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
+            list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
+
             nfe_step = gr.Number(label="n_step", value=32)
-            file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
+
+            with gr.Row():
+                cm_checkpoint = gr.Dropdown(
+                    choices=list_checkpoints, value=checkpoint_select, label="checkpoints", allow_custom_value=True
+                )
+                bt_checkpoint_refresh = gr.Button("refresh")
 
             random_sample_infer = gr.Button("random sample")
 
@@ -911,17 +1042,24 @@ with gr.Blocks() as app:
                 ref_audio = gr.Audio(label="audio ref", type="filepath")
                 gen_text = gr.Textbox(label="gen text")
             random_sample_infer.click(
-                fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
+                fn=get_random_sample_infer, inputs=[cm_project], outputs=[ref_text, gen_text, ref_audio]
             )
-            check_button_infer = gr.Button("infer")
+
+            with gr.Row():
+                txt_info_gpu = gr.Textbox("", label="device")
+                check_button_infer = gr.Button("infer")
+
             gen_audio = gr.Audio(label="audio gen", type="filepath")
 
             check_button_infer.click(
                 fn=infer,
-                inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
-                outputs=[gen_audio],
+                inputs=[cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step],
+                outputs=[gen_audio, txt_info_gpu],
            )
 
+            bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+            cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
+
 
 @click.command()
 @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
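
Finally, a hedged usage sketch for the reworked `extract_and_save_ema_model` helper behind the "reduse checkpoint" tab: it keeps only the EMA weights from a training checkpoint and, when the new safetensors checkbox is ticked, re-serializes them with `safetensors.torch.save_file`. The paths below are placeholders.

```python
# Illustrative call; the paths are placeholders.
msg = extract_and_save_ema_model(
    checkpoint_path="ckpts/my_speak/model_10000.pt",
    new_checkpoint_path="ckpts/my_speak/model_10000_ema.pt",
    safetensors=True,  # output path is rewritten to ..._ema.safetensors
)
print(msg)  # e.g. "New checkpoint saved at: ckpts/my_speak/model_10000_ema.safetensors"
```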