unknown committed on
Commit
1bbec4a
·
1 Parent(s): 2a844ae

add 8bit and fix some value

Browse files
src/f5_tts/train/finetune_cli.py CHANGED
@@ -13,6 +13,9 @@ from importlib.resources import files
13
  target_sample_rate = 24000
14
  n_mel_channels = 100
15
  hop_length = 256
 
 
 
16
 
17
 
18
  # -------------------------- Argument Parsing --------------------------- #
@@ -40,7 +43,7 @@ def parse_args():
40
  parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
41
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
42
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
43
- parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
44
  parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
45
  parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
46
  parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
@@ -121,11 +124,15 @@ def main():
121
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
122
 
123
  print("\nvocab : ", vocab_size)
 
124
 
125
  mel_spec_kwargs = dict(
126
- target_sample_rate=target_sample_rate,
127
- n_mel_channels=n_mel_channels,
128
  hop_length=hop_length,
 
 
 
 
129
  )
130
 
131
  model = CFM(
 
13
  target_sample_rate = 24000
14
  n_mel_channels = 100
15
  hop_length = 256
16
+ win_length = 1024
17
+ n_fft = 1024
18
+ mel_spec_type = "vocos" # 'vocos' or 'bigvgan'
19
 
20
 
21
  # -------------------------- Argument Parsing --------------------------- #
 
43
  parser.add_argument("--max_samples", type=int, default=64, help="Max sequences per batch")
44
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
45
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
46
+ parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
47
  parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
48
  parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
49
  parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
 
124
  vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
125
 
126
  print("\nvocab : ", vocab_size)
127
+ print("\nvocoder : ", mel_spec_type)
128
 
129
  mel_spec_kwargs = dict(
130
+ n_fft=n_fft,
 
131
  hop_length=hop_length,
132
+ win_length=win_length,
133
+ n_mel_channels=n_mel_channels,
134
+ target_sample_rate=target_sample_rate,
135
+ mel_spec_type=mel_spec_type,
136
  )
137
 
138
  model = CFM(
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -43,6 +43,13 @@ last_ema = None
43
 
44
  path_data = str(files("f5_tts").joinpath("../../data"))
45
  path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
 
 
 
 
 
 
 
46
  file_train = "src/f5_tts/train/finetune_cli.py"
47
 
48
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -70,6 +77,7 @@ def save_settings(
70
  tokenizer_file,
71
  mixed_precision,
72
  logger,
 
73
  ):
74
  path_project = os.path.join(path_project_ckpts, project_name)
75
  os.makedirs(path_project, exist_ok=True)
@@ -93,6 +101,7 @@ def save_settings(
93
  "tokenizer_file": tokenizer_file,
94
  "mixed_precision": mixed_precision,
95
  "logger": logger,
 
96
  }
97
  with open(file_setting, "w") as f:
98
  json.dump(settings, f, indent=4)
@@ -124,6 +133,7 @@ def load_settings(project_name):
124
  "tokenizer_file": "",
125
  "mixed_precision": "none",
126
  "logger": "wandb",
 
127
  }
128
  return (
129
  settings["exp_name"],
@@ -143,12 +153,15 @@ def load_settings(project_name):
143
  settings["tokenizer_file"],
144
  settings["mixed_precision"],
145
  settings["logger"],
 
146
  )
147
 
148
  with open(file_setting, "r") as f:
149
  settings = json.load(f)
150
  if "logger" not in settings:
151
  settings["logger"] = "wandb"
 
 
152
  return (
153
  settings["exp_name"],
154
  settings["learning_rate"],
@@ -167,6 +180,7 @@ def load_settings(project_name):
167
  settings["tokenizer_file"],
168
  settings["mixed_precision"],
169
  settings["logger"],
 
170
  )
171
 
172
 
@@ -381,6 +395,7 @@ def start_training(
381
  mixed_precision="fp16",
382
  stream=False,
383
  logger="wandb",
 
384
  ):
385
  global training_process, tts_api, stop_signal, pipe
386
 
@@ -447,11 +462,10 @@ def start_training(
447
  f"--dataset_name {dataset_name}"
448
  )
449
 
450
- if finetune:
451
- cmd += f" --finetune {finetune}"
452
 
453
  if file_checkpoint_train != "":
454
- cmd += f" --file_checkpoint_train {file_checkpoint_train}"
455
 
456
  if tokenizer_file != "":
457
  cmd += f" --tokenizer_path {tokenizer_file}"
@@ -460,7 +474,10 @@ def start_training(
460
 
461
  cmd += f" --log_samples True --logger {logger} "
462
 
463
- print(cmd)
 
 
 
464
 
465
  save_settings(
466
  dataset_name,
@@ -481,6 +498,7 @@ def start_training(
481
  tokenizer_file,
482
  mixed_precision,
483
  logger,
 
484
  )
485
 
486
  try:
@@ -758,11 +776,9 @@ def get_correct_audio_path(
758
  # Case 2: If it has a supported extension but is not a full path
759
  elif has_supported_extension(audio_input) and not os.path.isabs(audio_input):
760
  file_audio = os.path.join(base_path, audio_input)
761
- print("2")
762
 
763
  # Case 3: If only the name is given (no extension and not a full path)
764
  elif not has_supported_extension(audio_input) and not os.path.isabs(audio_input):
765
- print("3")
766
  for ext in supported_formats:
767
  potential_file = os.path.join(base_path, f"{audio_input}.{ext}")
768
  if os.path.exists(potential_file):
@@ -773,6 +789,18 @@ def get_correct_audio_path(
773
  return file_audio
774
 
775
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
777
  path_project = os.path.join(path_data, name_project)
778
  path_project_wavs = os.path.join(path_project, "wavs")
@@ -816,9 +844,12 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
816
  continue
817
 
818
  if duration < 1 or duration > 25:
819
- error_files.append([file_audio, "duration < 1 or > 25 "])
 
 
 
820
  continue
821
- if len(text) < 4:
822
  error_files.append([file_audio, "very small text len 3"])
823
  continue
824
 
@@ -1208,7 +1239,9 @@ def get_random_sample_infer(project_name):
1208
  )
1209
 
1210
 
1211
- def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema):
 
 
1212
  global last_checkpoint, last_device, tts_api, last_ema
1213
 
1214
  if not os.path.isfile(file_checkpoint):
@@ -1238,8 +1271,17 @@ def infer(project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe
1238
  print("update >> ", device_test, file_checkpoint, use_ema)
1239
 
1240
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1241
- tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
1242
- return f.name, tts_api.device
 
 
 
 
 
 
 
 
 
1243
 
1244
 
1245
  def check_finetune(finetune):
@@ -1506,6 +1548,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1506
  ```"""
1507
  )
1508
  ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
 
1509
  bt_prepare = bt_create = gr.Button("Prepare")
1510
  txt_info_prepare = gr.Text(label="Info", value="")
1511
  txt_vocab_prepare = gr.Text(label="Vocab", value="")
@@ -1560,6 +1603,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1560
  last_per_steps = gr.Number(label="Last per Steps", value=100)
1561
 
1562
  with gr.Row():
 
1563
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1564
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1565
  start_button = gr.Button("Start Training")
@@ -1584,6 +1628,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1584
  tokenizer_filev,
1585
  mixed_precisionv,
1586
  cd_loggerv,
 
1587
  ) = load_settings(projects_selelect)
1588
  exp_name.value = exp_namev
1589
  learning_rate.value = learning_ratev
@@ -1602,6 +1647,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1602
  tokenizer_file.value = tokenizer_filev
1603
  mixed_precision.value = mixed_precisionv
1604
  cd_logger.value = cd_loggerv
 
1605
 
1606
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1607
  txt_info_train = gr.Text(label="Info", value="")
@@ -1660,6 +1706,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1660
  mixed_precision,
1661
  ch_stream,
1662
  cd_logger,
 
1663
  ],
1664
  outputs=[txt_info_train, start_button, stop_button],
1665
  )
@@ -1732,12 +1779,17 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1732
 
1733
  with gr.TabItem("Test Model"):
1734
  gr.Markdown("""```plaintext
1735
- SOS: Check the use_ema setting (True or False) for your model to see what works best for you.
1736
  ```""")
1737
  exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
1738
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1739
 
1740
- nfe_step = gr.Number(label="NFE Step", value=32)
 
 
 
 
 
1741
  ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
1742
  with gr.Row():
1743
  cm_checkpoint = gr.Dropdown(
@@ -1757,14 +1809,27 @@ SOS: Check the use_ema setting (True or False) for your model to see what works
1757
 
1758
  with gr.Row():
1759
  txt_info_gpu = gr.Textbox("", label="Device")
 
1760
  check_button_infer = gr.Button("Infer")
1761
 
1762
  gen_audio = gr.Audio(label="Audio Gen", type="filepath")
1763
 
1764
  check_button_infer.click(
1765
  fn=infer,
1766
- inputs=[cm_project, cm_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, ch_use_ema],
1767
- outputs=[gen_audio, txt_info_gpu],
 
 
 
 
 
 
 
 
 
 
 
 
1768
  )
1769
 
1770
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
 
43
 
44
  path_data = str(files("f5_tts").joinpath("../../data"))
45
  path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
46
+
47
+ from pathlib import Path
48
+
49
+ base_path = Path(__file__).resolve().parent.parent.parent.parent
50
+ path_data = str(base_path / "data")
51
+ path_project_ckpts = str(base_path / "ckpts")
52
+
53
  file_train = "src/f5_tts/train/finetune_cli.py"
54
 
55
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
77
  tokenizer_file,
78
  mixed_precision,
79
  logger,
80
+ ch_8bit_adam,
81
  ):
82
  path_project = os.path.join(path_project_ckpts, project_name)
83
  os.makedirs(path_project, exist_ok=True)
 
101
  "tokenizer_file": tokenizer_file,
102
  "mixed_precision": mixed_precision,
103
  "logger": logger,
104
+ "bnb_optimizer": ch_8bit_adam,
105
  }
106
  with open(file_setting, "w") as f:
107
  json.dump(settings, f, indent=4)
 
133
  "tokenizer_file": "",
134
  "mixed_precision": "none",
135
  "logger": "wandb",
136
+ "bnb_optimizer": False,
137
  }
138
  return (
139
  settings["exp_name"],
 
153
  settings["tokenizer_file"],
154
  settings["mixed_precision"],
155
  settings["logger"],
156
+ settings["bnb_optimizer"],
157
  )
158
 
159
  with open(file_setting, "r") as f:
160
  settings = json.load(f)
161
  if "logger" not in settings:
162
  settings["logger"] = "wandb"
163
+ if "bnb_optimizer" not in settings:
164
+ settings["bnb_optimizer"] = False
165
  return (
166
  settings["exp_name"],
167
  settings["learning_rate"],
 
180
  settings["tokenizer_file"],
181
  settings["mixed_precision"],
182
  settings["logger"],
183
+ settings["bnb_optimizer"],
184
  )
185
 
186
 
 
395
  mixed_precision="fp16",
396
  stream=False,
397
  logger="wandb",
398
+ ch_8bit_adam=False,
399
  ):
400
  global training_process, tts_api, stop_signal, pipe
401
 
 
462
  f"--dataset_name {dataset_name}"
463
  )
464
 
465
+ cmd += f" --finetune {finetune}"
 
466
 
467
  if file_checkpoint_train != "":
468
+ cmd += f" --pretrain {file_checkpoint_train}"
469
 
470
  if tokenizer_file != "":
471
  cmd += f" --tokenizer_path {tokenizer_file}"
 
474
 
475
  cmd += f" --log_samples True --logger {logger} "
476
 
477
+ if ch_8bit_adam:
478
+ cmd += " --bnb_optimizer True "
479
+
480
+ print("run command : \n" + cmd + "\n")
481
 
482
  save_settings(
483
  dataset_name,
 
498
  tokenizer_file,
499
  mixed_precision,
500
  logger,
501
+ ch_8bit_adam,
502
  )
503
 
504
  try:
 
776
  # Case 2: If it has a supported extension but is not a full path
777
  elif has_supported_extension(audio_input) and not os.path.isabs(audio_input):
778
  file_audio = os.path.join(base_path, audio_input)
 
779
 
780
  # Case 3: If only the name is given (no extension and not a full path)
781
  elif not has_supported_extension(audio_input) and not os.path.isabs(audio_input):
 
782
  for ext in supported_formats:
783
  potential_file = os.path.join(base_path, f"{audio_input}.{ext}")
784
  if os.path.exists(potential_file):
 
789
  return file_audio
790
 
791
 
792
def get_nested_value(data, format):
    """Look up a value in nested dicts using a '/'-separated key path.

    Args:
        data: Nested dictionary to search.
        format: Key path such as "outer/inner/leaf"; each segment is one
            dict level.

    Returns:
        The value at the end of the path, or None as soon as any key
        along the path is missing.
    """
    current = data
    for part in format.split("/"):
        current = current.get(part)
        # Bail out early: a missing key anywhere means the path is absent.
        if current is None:
            return None
    return current
802
+
803
+
804
  def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
805
  path_project = os.path.join(path_data, name_project)
806
  path_project_wavs = os.path.join(path_project, "wavs")
 
844
  continue
845
 
846
  if duration < 1 or duration > 25:
847
+ if duration > 25:
848
+ error_files.append([file_audio, "duration > 25 sec"])
849
+ if duration < 1:
850
+ error_files.append([file_audio, "duration < 1 sec "])
851
  continue
852
+ if len(text) < 3:
853
  error_files.append([file_audio, "very small text len 3"])
854
  continue
855
 
 
1239
  )
1240
 
1241
 
1242
+ def infer(
1243
+ project, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step, use_ema, speed, seed, remove_silence
1244
+ ):
1245
  global last_checkpoint, last_device, tts_api, last_ema
1246
 
1247
  if not os.path.isfile(file_checkpoint):
 
1271
  print("update >> ", device_test, file_checkpoint, use_ema)
1272
 
1273
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
1274
+ tts_api.infer(
1275
+ gen_text=gen_text.lower().strip(),
1276
+ ref_text=ref_text.lower().strip(),
1277
+ ref_file=ref_audio,
1278
+ nfe_step=nfe_step,
1279
+ file_wave=f.name,
1280
+ speed=speed,
1281
+ seed=seed,
1282
+ remove_silence=remove_silence,
1283
+ )
1284
+ return f.name, tts_api.device, str(tts_api.seed)
1285
 
1286
 
1287
  def check_finetune(finetune):
 
1548
  ```"""
1549
  )
1550
  ch_tokenizern = gr.Checkbox(label="Create Vocabulary", value=False, visible=False)
1551
+
1552
  bt_prepare = bt_create = gr.Button("Prepare")
1553
  txt_info_prepare = gr.Text(label="Info", value="")
1554
  txt_vocab_prepare = gr.Text(label="Vocab", value="")
 
1603
  last_per_steps = gr.Number(label="Last per Steps", value=100)
1604
 
1605
  with gr.Row():
1606
+ ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
1607
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1608
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1609
  start_button = gr.Button("Start Training")
 
1628
  tokenizer_filev,
1629
  mixed_precisionv,
1630
  cd_loggerv,
1631
+ ch_8bit_adamv,
1632
  ) = load_settings(projects_selelect)
1633
  exp_name.value = exp_namev
1634
  learning_rate.value = learning_ratev
 
1647
  tokenizer_file.value = tokenizer_filev
1648
  mixed_precision.value = mixed_precisionv
1649
  cd_logger.value = cd_loggerv
1650
+ ch_8bit_adam.value = ch_8bit_adamv
1651
 
1652
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1653
  txt_info_train = gr.Text(label="Info", value="")
 
1706
  mixed_precision,
1707
  ch_stream,
1708
  cd_logger,
1709
+ ch_8bit_adam,
1710
  ],
1711
  outputs=[txt_info_train, start_button, stop_button],
1712
  )
 
1779
 
1780
  with gr.TabItem("Test Model"):
1781
  gr.Markdown("""```plaintext
1782
+ SOS: Check the use_ema setting (True or False) for your model to see what works best for you. use seed -1 from random
1783
  ```""")
1784
  exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
1785
  list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
1786
 
1787
+ with gr.Row():
1788
+ nfe_step = gr.Number(label="NFE Step", value=32)
1789
+ speed = gr.Slider(label="Speed", value=1.0, minimum=0.3, maximum=2.0, step=0.1)
1790
+ seed = gr.Number(label="Seed", value=-1, minimum=-1)
1791
+ remove_silence = gr.Checkbox(label="Remove Silence")
1792
+
1793
  ch_use_ema = gr.Checkbox(label="Use EMA", value=True)
1794
  with gr.Row():
1795
  cm_checkpoint = gr.Dropdown(
 
1809
 
1810
  with gr.Row():
1811
  txt_info_gpu = gr.Textbox("", label="Device")
1812
+ seed_info = gr.Text(label="Seed :")
1813
  check_button_infer = gr.Button("Infer")
1814
 
1815
  gen_audio = gr.Audio(label="Audio Gen", type="filepath")
1816
 
1817
  check_button_infer.click(
1818
  fn=infer,
1819
+ inputs=[
1820
+ cm_project,
1821
+ cm_checkpoint,
1822
+ exp_name,
1823
+ ref_text,
1824
+ ref_audio,
1825
+ gen_text,
1826
+ nfe_step,
1827
+ ch_use_ema,
1828
+ speed,
1829
+ seed,
1830
+ remove_silence,
1831
+ ],
1832
+ outputs=[gen_audio, txt_info_gpu, seed_info],
1833
  )
1834
 
1835
  bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])