unknown committed on
Commit 37eb3b5 · 1 Parent(s): d601a70

add tensorboard and add export sample for mel and audio
src/f5_tts/model/trainer.py CHANGED
@@ -3,7 +3,11 @@ from __future__ import annotations
 import os
 import gc
 from tqdm import tqdm
-import wandb
+
+try:
+    from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+    print("TensorBoard is not installed")
 
 import torch
 from torch.optim import AdamW
@@ -19,9 +23,26 @@ from f5_tts.model import CFM
 from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 
-
+import numpy as np
+import matplotlib.pyplot as plt
 # trainer
 
+# audio imports
+import torchaudio
+import soundfile as sf
+from vocos import Vocos
+import warnings
+
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+# -----------------------------------------
+target_sample_rate = 24000
+hop_length = 256
+nfe_step = 16
+cfg_strength = 2.0
+sway_sampling_coef = -1.0
+# -----------------------------------------
+
 
 class Trainer:
     def __init__(
@@ -39,6 +60,8 @@ class Trainer:
         max_grad_norm=1.0,
         noise_scheduler: str | None = None,
         duration_predictor: torch.nn.Module | None = None,
+        logger: str = "wandb",  # logging backend: "wandb" | "tensorboard" | "none"
+        log_dir: str = "logs",  # root directory for TensorBoard logs
         wandb_project="test_e2-tts",
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
@@ -46,24 +69,24 @@
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
+        export_samples=False,
     ):
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
-        logger = "wandb" if wandb.api.api_key else None
-        print(f"Using logger: {logger}")
-
-        self.accelerator = Accelerator(
-            log_with=logger,
-            kwargs_handlers=[ddp_kwargs],
-            gradient_accumulation_steps=grad_accumulation_steps,
-            **accelerate_kwargs,
-        )
+        self.logger = logger
+        if self.logger == "wandb":
+            self.accelerator = Accelerator(
+                log_with="wandb",
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
 
-        if logger == "wandb":
             if exists(wandb_resume_id):
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
             else:
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
+
             self.accelerator.init_trackers(
                 project_name=wandb_project,
                 init_kwargs=init_kwargs,
@@ -80,12 +103,37 @@
                     "noise_scheduler": noise_scheduler,
                 },
             )
+        elif self.logger == "tensorboard":
+            self.accelerator = Accelerator(
+                kwargs_handlers=[ddp_kwargs],
+                gradient_accumulation_steps=grad_accumulation_steps,
+                **accelerate_kwargs,
+            )
+            if self.is_main:
+                path_log_dir = os.path.join(log_dir, wandb_project)
+                os.makedirs(path_log_dir, exist_ok=True)
+                existing_folders = [folder for folder in os.listdir(path_log_dir) if folder.startswith("exp")]
+                next_number = len(existing_folders) + 2
+                folder_name = f"exp{next_number}"
+                folder_path = os.path.join(path_log_dir, folder_name)
+                os.makedirs(folder_path, exist_ok=True)
+
+                self.writer = SummaryWriter(log_dir=folder_path)
+
+        # export audio and mel
+        self.export_samples = export_samples
+        if self.export_samples:
+            self.path_ckpts_project = checkpoint_path
+
+            self.vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+            self.vocos.to("cpu")
+            self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
+            os.makedirs(self.file_path_samples, exist_ok=True)
 
         self.model = model
 
         if self.is_main:
             self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)
-
             self.ema_model.to(self.accelerator.device)
 
         self.epochs = epochs
@@ -175,6 +223,82 @@
             gc.collect()
             return step
 
+    def log(self, metrics, step):
+        """Unified logging method for both WandB and TensorBoard"""
+        if self.logger == "none":
+            return
+        if self.logger == "wandb":
+            self.accelerator.log(metrics, step=step)
+        elif self.is_main:
+            for key, value in metrics.items():
+                self.writer.add_scalar(key, value, step)
+
+    def export_add_log(self, global_step, mel_org, text_inputs):
+        try:
+            generated_wave_org = self.vocos.decode(mel_org.unsqueeze(0).cpu())
+            generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
+            file_wav_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.wav")
+            sf.write(file_wav_org, generated_wave_org, target_sample_rate)
+
+            audio, sr = torchaudio.load(file_wav_org)
+            audio = audio.to("cuda")
+
+            ref_audio_len = audio.shape[-1] // hop_length
+            text = [text_inputs[0] + [" . "] + text_inputs[0]]
+            duration = int((audio.shape[1] / 256) * 2.0)
+
+            with torch.inference_mode():
+                generated_gen, _ = self.model.sample(
+                    cond=audio,
+                    text=text,
+                    duration=duration,
+                    steps=nfe_step,
+                    cfg_strength=cfg_strength,
+                    sway_sampling_coef=sway_sampling_coef,
+                )
+
+                generated_gen = generated_gen.to(torch.float32)
+                generated_gen = generated_gen[:, ref_audio_len:, :]
+                generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
+                generated_wave_gen = self.vocos.decode(generated_mel_spec_gen.cpu())
+                generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
+                file_wav_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.wav")
+                sf.write(file_wav_gen, generated_wave_gen, target_sample_rate)
+
+            if self.logger == "tensorboard":
+                self.writer.add_audio("Audio/original", generated_wave_org, global_step, sample_rate=target_sample_rate)
+
+                self.writer.add_audio("Audio/generate", generated_wave_gen, global_step, sample_rate=target_sample_rate)
+
+            # ground-truth mel -> colored image
+            mel_min, mel_max = mel_org.min(), mel_org.max()
+            mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
+            mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+            mel_colored = np.transpose(mel_colored, (2, 0, 1))
+
+            if self.logger == "tensorboard":
+                self.writer.add_image("Mel/original", mel_colored, global_step, dataformats="CHW")
+
+            mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0))
+            file_gen_org = os.path.join(self.file_path_samples, f"step_{global_step}_org.png")
+            plt.imsave(file_gen_org, mel_colored_hwc)
+
+            mel_gen = generated_mel_spec_gen[0]
+            mel_min, mel_max = mel_gen.min(), mel_gen.max()
+            mel_norm = (mel_gen - mel_min) / (mel_max - mel_min + 1e-8)
+            mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
+            mel_colored = np.transpose(mel_colored, (2, 0, 1))
+
+            if self.logger == "tensorboard":
+                self.writer.add_image("Mel/generate", mel_colored, global_step, dataformats="CHW")
+
+            mel_colored_hwc = np.transpose(mel_colored, (1, 2, 0))
+            file_gen_gen = os.path.join(self.file_path_samples, f"step_{global_step}_gen.png")
+            plt.imsave(file_gen_gen, mel_colored_hwc)
+
+        except Exception as e:
+            print("An error occurred:", e)
+
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
         if exists(resumable_with_seed):
             generator = torch.Generator()
@@ -270,6 +394,15 @@
                     loss, cond, pred = self.model(
                         mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                     )
+
+                    # export 4 audio/mel sample pairs per save interval
+                    if (
+                        self.accelerator.is_local_main_process
+                        and self.export_samples
+                        and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
+                    ):
+                        self.export_add_log(global_step, batch["mel"][0], text_inputs)
+
                     self.accelerator.backward(loss)
 
                     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -285,7 +418,7 @@
                 global_step += 1
 
                 if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
 
                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())
 
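The new `Trainer.log` gives both backends a single call site: with `logger="wandb"` metrics go through `accelerator.log`, with `logger="tensorboard"` they go straight to the `SummaryWriter` that `__init__` creates under `logs/<wandb_project>/exp<N>` (viewable with `tensorboard --logdir logs/<project>`), and `logger="none"` discards them. A minimal standalone sketch of that routing; the `ScalarLogger` class is illustrative, not part of the commit:

```python
from torch.utils.tensorboard import SummaryWriter


class ScalarLogger:
    """Illustrative stand-in for the routing inside Trainer.log."""

    def __init__(self, logger: str = "tensorboard", log_dir: str = "logs/test/exp1"):
        self.logger = logger
        # The commit only creates a writer on the main process when
        # logger == "tensorboard"; wandb runs go through accelerator.log instead.
        self.writer = SummaryWriter(log_dir=log_dir) if logger == "tensorboard" else None

    def log(self, metrics: dict, step: int):
        if self.logger == "none":
            return  # logging disabled
        if self.logger == "tensorboard":
            for key, value in metrics.items():
                self.writer.add_scalar(key, value, step)
        # logger == "wandb": the Trainer delegates to self.accelerator.log(metrics, step=step)


logger = ScalarLogger()
logger.log({"loss": 0.5, "lr": 1e-4}, step=1)
```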
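`export_add_log` round-trips one training sample: the ground-truth mel is vocoded to `step_<N>_org.wav`, that audio becomes the prompt for a fresh `model.sample` call, and both waveforms plus viridis-rendered mel images land in the checkpoint directory's `samples/` folder (and in TensorBoard when active). The gate `global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0` fires four times per checkpoint interval, e.g. every 250 steps with `save_per_updates=1000` and `grad_accumulation_steps=1`. A minimal sketch of the mel-to-audio leg, assuming `vocos` and `soundfile` are installed; the random mel is a stand-in for a real `batch["mel"][0]`:

```python
import torch
import soundfile as sf
from vocos import Vocos

# Same pretrained vocoder checkpoint the Trainer loads.
vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

mel = torch.randn(1, 100, 200)  # (batch, n_mels=100, frames); placeholder spectrogram
with torch.inference_mode():
    wave = vocos.decode(mel)  # -> (1, frames * hop_length) samples at 24 kHz

sf.write("step_0_org.wav", wave.squeeze().numpy(), 24000)
```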
src/f5_tts/train/finetune_cli.py CHANGED
@@ -56,6 +56,14 @@ def parse_args():
         help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
     )
 
+    parser.add_argument(
+        "--export_samples",
+        type=bool,
+        default=False,
+        help="Export 4 audio and spectrogram samples per save interval.",
+    )
+    parser.add_argument("--logger", type=str, default="wandb", choices=["none", "wandb", "tensorboard"], help="logger")
+
     return parser.parse_args()
 
 
@@ -64,6 +72,7 @@ def parse_args():
 
 def main():
     args = parse_args()
+
     checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
 
     # Model parameters based on experiment name
@@ -136,6 +145,8 @@ def main():
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
         last_per_steps=args.last_per_steps,
+        logger=args.logger,
+        export_samples=args.export_samples,
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
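One caveat with the new `--export_samples` flag: argparse's `type=bool` applies `bool()` to the raw string, and any non-empty string is truthy, so `--export_samples False` still enables export. A quick self-contained demonstration (not part of the commit):

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the flag added above: type=bool calls bool() on the raw string.
parser.add_argument("--export_samples", type=bool, default=False)

print(parser.parse_args([]).export_samples)                              # False (default)
print(parser.parse_args(["--export_samples", "True"]).export_samples)    # True
print(parser.parse_args(["--export_samples", "False"]).export_samples)   # True: bool("False") is truthy
```

In practice the flag behaves as presence-only: finetune_gradio.py always appends `--export_samples True`, and omitting the flag keeps the default `False`.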
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -447,6 +447,8 @@ def start_training(
 
     cmd += f" --tokenizer {tokenizer_type} "
 
+    cmd += " --export_samples True --logger wandb "
+
     print(cmd)
 
     save_settings(
@@ -1223,6 +1225,27 @@ def get_checkpoints_project(project_name, is_gradio=True):
     return files_checkpoints, selelect_checkpoint
 
 
+def get_audio_project(project_name, is_gradio=True):
+    if project_name is None:
+        return [], ""
+    project_name = project_name.replace("_pinyin", "").replace("_char", "")
+
+    if os.path.isdir(path_project_ckpts):
+        files_audios = glob(os.path.join(path_project_ckpts, project_name, "samples", "*.wav"))
+        files_audios = sorted(files_audios, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
+
+        files_audios = [item.replace("_gen.wav", "") for item in files_audios if item.endswith("_gen.wav")]
+    else:
+        files_audios = []
+
+    selelect_checkpoint = None if not files_audios else files_audios[0]
+
+    if is_gradio:
+        return gr.update(choices=files_audios, value=selelect_checkpoint)
+
+    return files_audios, selelect_checkpoint
+
+
 def get_gpu_stats():
     gpu_stats = ""
 
@@ -1290,6 +1313,21 @@ def get_combined_stats():
     return combined_stats
 
 
+def get_audio_select(file_sample):
+    select_audio_org = file_sample
+    select_audio_gen = file_sample
+    select_image_org = file_sample
+    select_image_gen = file_sample
+
+    if file_sample is not None:
+        select_audio_org += "_org.wav"
+        select_audio_gen += "_gen.wav"
+        select_image_org += "_org.png"
+        select_image_gen += "_gen.png"
+
+    return select_audio_org, select_audio_gen, select_image_org, select_image_gen
+
+
 with gr.Blocks() as app:
     gr.Markdown(
         """
@@ -1511,6 +1549,47 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
 
             ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
             txt_info_train = gr.Text(label="info", value="")
+
+            list_audios, select_audio = get_audio_project(projects_selelect, False)
+
+            select_audio_org = select_audio
+            select_audio_gen = select_audio
+            select_image_org = select_audio
+            select_image_gen = select_audio
+
+            if select_audio is not None:
+                select_audio_org += "_org.wav"
+                select_audio_gen += "_gen.wav"
+                select_image_org += "_org.png"
+                select_image_gen += "_gen.png"
+
+            with gr.Row():
+                ch_list_audio = gr.Dropdown(
+                    choices=list_audios,
+                    value=select_audio,
+                    label="audios",
+                    allow_custom_value=True,
+                    scale=6,
+                    interactive=True,
+                )
+                bt_stream_audio = gr.Button("refresh", scale=1)
+                bt_stream_audio.click(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+                cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
+
+            with gr.Row():
+                audio_org_stream = gr.Audio(label="original", type="filepath", value=select_audio_org)
+                mel_org_stream = gr.Image(label="original", type="filepath", value=select_image_org)
+
+            with gr.Row():
+                audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
+                mel_gen_stream = gr.Image(label="generate", type="filepath", value=select_image_gen)
+
+            ch_list_audio.change(
+                fn=get_audio_select,
+                inputs=[ch_list_audio],
+                outputs=[audio_org_stream, audio_gen_stream, mel_org_stream, mel_gen_stream],
+            )
+
             start_button.click(
                 fn=start_training,
                 inputs=[
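The new Gradio widgets all hang off one filename convention: `export_add_log` writes four files per sampled step, `get_audio_project` collects the sample `.wav` files, sorts them by step number, and strips `_gen.wav` to get a per-step stem, and `get_audio_select` re-appends each suffix. A small round-trip sketch with made-up paths:

```python
import os

# Hypothetical files, as export_add_log would write them under <ckpts>/<project>/samples/.
files = [
    "ckpts/demo/samples/step_1000_gen.wav",
    "ckpts/demo/samples/step_250_gen.wav",
    "ckpts/demo/samples/step_500_gen.wav",
]

# Same sort key as get_audio_project: the integer after "step_" in the basename.
files = sorted(files, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]))
stems = [f.replace("_gen.wav", "") for f in files]
print(stems)  # ['...step_250', '...step_500', '...step_1000']

# get_audio_select derives all four artifacts from the selected stem.
stem = stems[0]
print([stem + suffix for suffix in ("_org.wav", "_gen.wav", "_org.png", "_gen.png")])
```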