Commit da1b409 · SWivid committed
Parent: 5b10099

basic structure
src/f5_tts/model/trainer.py CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations
 import os
 import gc
 from tqdm import tqdm
-
+import wandb
 
 import torch
+import torchaudio
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset, SequentialSampler
 from torch.optim.lr_scheduler import LinearLR, SequentialLR
@@ -19,6 +20,7 @@ from f5_tts.model import CFM
 from f5_tts.model.utils import exists, default
 from f5_tts.model.dataset import DynamicBatchSampler, collate_fn
 
+
 # trainer
 
 
@@ -38,33 +40,32 @@ class Trainer:
         max_grad_norm=1.0,
         noise_scheduler: str | None = None,
         duration_predictor: torch.nn.Module | None = None,
-        logger: str = "wandb",  # Add logger parameter wandb,tensorboard , none
-        log_dir: str = "logs",  # Add log directory parameter
+        logger: str | None = "wandb",  # "wandb" | "tensorboard" | None
         wandb_project="test_e2-tts",
         wandb_run_name="test_run",
         wandb_resume_id: str = None,
+        log_samples: bool = False,
         last_per_steps=None,
         accelerate_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         bnb_optimizer: bool = False,
-        export_samples=False,
     ):
-        # export audio and mel
-        self.export_samples = export_samples
-        if export_samples:
-            self.path_ckpts_project = checkpoint_path
-
         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
 
+        if logger == "wandb" and not wandb.api.api_key:
+            logger = None
+        print(f"Using logger: {logger}")
+        self.log_samples = log_samples
+
+        self.accelerator = Accelerator(
+            log_with=logger if logger == "wandb" else None,
+            kwargs_handlers=[ddp_kwargs],
+            gradient_accumulation_steps=grad_accumulation_steps,
+            **accelerate_kwargs,
+        )
+
         self.logger = logger
         if self.logger == "wandb":
-            self.accelerator = Accelerator(
-                log_with="wandb",
-                kwargs_handlers=[ddp_kwargs],
-                gradient_accumulation_steps=grad_accumulation_steps,
-                **accelerate_kwargs,
-            )
-
             if exists(wandb_resume_id):
                 init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
             else:
@@ -86,24 +87,11 @@ class Trainer:
                     "noise_scheduler": noise_scheduler,
                 },
             )
+
         elif self.logger == "tensorboard":
             from torch.utils.tensorboard import SummaryWriter
 
-            self.accelerator = Accelerator(
-                kwargs_handlers=[ddp_kwargs],
-                gradient_accumulation_steps=grad_accumulation_steps,
-                **accelerate_kwargs,
-            )
-            if self.is_main:
-                path_log_dir = os.path.join(log_dir, wandb_project)
-                os.makedirs(path_log_dir, exist_ok=True)
-                existing_folders = [folder for folder in os.listdir(path_log_dir) if folder.startswith("exp")]
-                next_number = len(existing_folders) + 2
-                folder_name = f"exp{next_number}"
-                folder_path = os.path.join(path_log_dir, folder_name)
-                os.makedirs(folder_path, exist_ok=True)
-
-                self.writer = SummaryWriter(log_dir=folder_path)
+            self.writer = SummaryWriter(log_dir=f"runs/{wandb_run_name}")
 
         self.model = model
 
@@ -198,31 +186,13 @@ class Trainer:
             gc.collect()
         return step
 
-    def log(self, metrics, step):
-        """Unified logging method for both WandB and TensorBoard"""
-        if self.logger == "none":
-            return
-        if self.logger == "wandb":
-            self.accelerator.log(metrics, step=step)
-        elif self.is_main:
-            for key, value in metrics.items():
-                self.writer.add_scalar(key, value, step)
-
     def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
-        # import only when export_sample True
-        if self.export_samples:
-            from f5_tts.infer.utils_infer import (
-                target_sample_rate,
-                hop_length,
-                nfe_step,
-                cfg_strength,
-                sway_sampling_coef,
-                vocos,
-            )
-            from f5_tts.model.utils import get_sample
+        if self.log_samples:
+            from f5_tts.infer.utils_infer import vocos, nfe_step, cfg_strength, sway_sampling_coef
 
-            self.file_path_samples = os.path.join(self.path_ckpts_project, "samples")
-            os.makedirs(self.file_path_samples, exist_ok=True)
+            target_sample_rate = self.model.mel_spec.mel_stft.sample_rate
+            log_samples_path = f"{self.checkpoint_path}/samples"
+            os.makedirs(log_samples_path, exist_ok=True)
 
         if exists(resumable_with_seed):
             generator = torch.Generator()
@@ -307,7 +277,6 @@ class Trainer:
             for batch in progress_bar:
                 with self.accelerator.accumulate(self.model):
                     text_inputs = batch["text"]
-
                     mel_spec = batch["mel"].permute(0, 2, 1)
                     mel_lengths = batch["mel_lengths"]
 
@@ -319,40 +288,6 @@ class Trainer:
                     loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
                    )
-
-                    # save 4 audio per save step
-                    if (
-                        self.accelerator.is_local_main_process
-                        and self.export_samples
-                        and global_step % (int(self.save_per_updates * 0.25) * self.grad_accumulation_steps) == 0
-                    ):
-                        try:
-                            wave_org, wave_gen, mel_org, mel_gen = get_sample(
-                                vocos,
-                                self.model,
-                                self.file_path_samples,
-                                global_step,
-                                batch["mel"][0],
-                                text_inputs,
-                                target_sample_rate,
-                                hop_length,
-                                nfe_step,
-                                cfg_strength,
-                                sway_sampling_coef,
-                            )
-
-                            if self.logger == "tensorboard":
-                                self.writer.add_audio(
-                                    "Audio/original", wave_org, global_step, sample_rate=target_sample_rate
-                                )
-                                self.writer.add_audio(
-                                    "Audio/generate", wave_gen, global_step, sample_rate=target_sample_rate
-                                )
-                                self.writer.add_image("Mel/original", mel_org, global_step, dataformats="CHW")
-                                self.writer.add_image("Mel/generate", mel_gen, global_step, dataformats="CHW")
-                        except Exception as e:
-                            print("An error occurred:", e)
-
                     self.accelerator.backward(loss)
 
                     if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -368,13 +303,32 @@ class Trainer:
                 global_step += 1

                 if self.accelerator.is_local_main_process:
-                    self.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
+                    if self.logger == "tensorboard":
+                        self.writer.add_scalar("loss", loss.item(), global_step)
+                        self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)

                 progress_bar.set_postfix(step=str(global_step), loss=loss.item())

                 if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                     self.save_checkpoint(global_step)

+                    if self.log_samples:
+                        ref_audio, ref_audio_len = vocos.decode(batch["mel"][0].unsqueeze(0).cpu()), mel_lengths[0]
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
+                        with torch.inference_mode():
+                            generated, _ = self.model.sample(
+                                cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
+                                text=[text_inputs[0] + [" "] + text_inputs[0]],
+                                duration=ref_audio_len * 2,
+                                steps=nfe_step,
+                                cfg_strength=cfg_strength,
+                                sway_sampling_coef=sway_sampling_coef,
+                            )
+                            generated = generated.to(torch.float32)
+                            gen_audio = vocos.decode(generated[:, ref_audio_len:, :].permute(0, 2, 1).cpu())
+                        torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)

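A note on the constructor hunk above: the Accelerator is now built once, up front, and a missing wandb API key silently downgrades to no logging instead of failing later at init_trackers time. A minimal standalone sketch of that selection logic (the wandb.api.api_key guard is taken from the diff; the helper name and defaults are illustrative):

import wandb
from accelerate import Accelerator

def build_accelerator(logger: str | None, grad_accumulation_steps: int = 1) -> Accelerator:
    # Mirrors the commit's guard: requesting wandb without a configured
    # API key falls back to running with no tracker at all.
    if logger == "wandb" and not wandb.api.api_key:
        logger = None
    print(f"Using logger: {logger}")
    # Only wandb is routed through Accelerator's tracker API; the
    # tensorboard path creates its own SummaryWriter instead.
    return Accelerator(
        log_with=logger if logger == "wandb" else None,
        gradient_accumulation_steps=grad_accumulation_steps,
    )

With no tracker configured, the unconditional self.accelerator.log(...) call in the train loop becomes a no-op, since Accelerator.log just forwards to its (then empty) list of trackers.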
 
src/f5_tts/model/utils.py CHANGED
@@ -11,10 +11,6 @@ from torch.nn.utils.rnn import pad_sequence
 import jieba
 from pypinyin import lazy_pinyin, Style
 
-import numpy as np
-import matplotlib.pyplot as plt
-import soundfile as sf
-import torchaudio
 
 # seed everything
 
@@ -187,74 +183,3 @@ def repetition_found(text, length=2, tolerance=10):
         if count > tolerance:
             return True
     return False
-
-
-def normalize_and_colorize_spectrogram(mel_org):
-    mel_min, mel_max = mel_org.min(), mel_org.max()
-    mel_norm = (mel_org - mel_min) / (mel_max - mel_min + 1e-8)
-    mel_colored = plt.get_cmap("viridis")(mel_norm.detach().cpu().numpy())[:, :, :3]
-    mel_colored = np.transpose(mel_colored, (2, 0, 1))
-    mel_colored = np.flip(mel_colored, axis=1)
-    return mel_colored
-
-
-def export_audio(file_out, wav, target_sample_rate):
-    sf.write(file_out, wav, samplerate=target_sample_rate)
-
-
-def export_mel(mel_colored_hwc, file_out):
-    plt.imsave(file_out, mel_colored_hwc)
-
-
-def gen_sample(model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef):
-    audio, sr = torchaudio.load(file_wav_org)
-    audio = audio.to("cuda")
-    ref_audio_len = audio.shape[-1] // hop_length
-    text = [text_inputs[0] + [" . "] + text_inputs[0]]
-    duration = int((audio.shape[1] / 256) * 2.0)
-    with torch.inference_mode():
-        generated_gen, _ = model.sample(
-            cond=audio,
-            text=text,
-            duration=duration,
-            steps=nfe_step,
-            cfg_strength=cfg_strength,
-            sway_sampling_coef=sway_sampling_coef,
-        )
-        generated_gen = generated_gen.to(torch.float32)
-        generated_gen = generated_gen[:, ref_audio_len:, :]
-        generated_mel_spec_gen = generated_gen.permute(0, 2, 1)
-        generated_wave_gen = vocos.decode(generated_mel_spec_gen.cpu())
-        generated_wave_gen = generated_wave_gen.squeeze().cpu().numpy()
-    return generated_wave_gen, generated_mel_spec_gen
-
-
-def get_sample(
-    vocos,
-    model,
-    file_path_samples,
-    global_step,
-    mel_org,
-    text_inputs,
-    target_sample_rate,
-    hop_length,
-    nfe_step,
-    cfg_strength,
-    sway_sampling_coef,
-):
-    generated_wave_org = vocos.decode(mel_org.unsqueeze(0).cpu())
-    generated_wave_org = generated_wave_org.squeeze().cpu().numpy()
-    file_wav_org = os.path.join(file_path_samples, f"step_{global_step}_org.wav")
-    export_audio(file_wav_org, generated_wave_org, target_sample_rate)
-    generated_wave_gen, generated_mel_spec_gen = gen_sample(
-        model, vocos, file_wav_org, text_inputs, hop_length, nfe_step, cfg_strength, sway_sampling_coef
-    )
-    file_wav_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.wav")
-    export_audio(file_wav_gen, generated_wave_gen, target_sample_rate)
-    mel_org = normalize_and_colorize_spectrogram(mel_org)
-    mel_gen = normalize_and_colorize_spectrogram(generated_mel_spec_gen[0])
-    file_gen_org = os.path.join(file_path_samples, f"step_{global_step}_org.png")
-    export_mel(np.transpose(mel_org, (1, 2, 0)), file_gen_org)
-    file_gen_gen = os.path.join(file_path_samples, f"step_{global_step}_gen.png")
-    export_mel(np.transpose(mel_gen, (1, 2, 0)), file_gen_gen)
-    return generated_wave_org, generated_wave_gen, mel_org, mel_gen
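With these helpers removed, utils.py no longer needs numpy, matplotlib, soundfile, or torchaudio, and training samples are saved as plain _ref.wav / _gen.wav pairs instead of wav plus png. Anyone who still wants the colored mel PNGs can keep a local equivalent; a minimal self-contained sketch (the function name is ours, not the repo's):

import numpy as np
import matplotlib.pyplot as plt
import torch

def mel_to_rgb(mel: torch.Tensor) -> np.ndarray:
    # Normalize an (n_mels, frames) mel to [0, 1] and map it through viridis,
    # returning an HWC array ready for plt.imsave(), low bins at the bottom.
    m = mel.detach().cpu().numpy()
    m = (m - m.min()) / (m.max() - m.min() + 1e-8)
    rgb = plt.get_cmap("viridis")(m)[:, :, :3]  # drop the alpha channel
    return np.flipud(rgb)

# e.g. plt.imsave(f"step_{global_step}_gen.png", mel_to_rgb(mel))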
 
src/f5_tts/train/finetune_cli.py CHANGED
@@ -57,12 +57,12 @@ def parse_args():
     )
 
     parser.add_argument(
-        "--export_samples",
+        "--log_samples",
         type=bool,
         default=False,
-        help="Export 4 audio and spect samples for the checkpoint audio, per step.",
+        help="Log inferenced samples per ckpt save steps",
     )
-    parser.add_argument("--logger", type=str, default="wandb", choices=["none", "wandb", "tensorboard"], help="logger")
+    parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
 
     return parser.parse_args()
 
@@ -141,12 +141,12 @@ def main():
         max_samples=args.max_samples,
         grad_accumulation_steps=args.grad_accumulation_steps,
         max_grad_norm=args.max_grad_norm,
+        logger=args.logger,
         wandb_project=args.dataset_name,
         wandb_run_name=args.exp_name,
         wandb_resume_id=wandb_resume_id,
+        log_samples=args.log_samples,
         last_per_steps=args.last_per_steps,
-        logger=args.logger,
-        export_samples=args.export_samples,
     )
 
     train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
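One caveat the renamed flag inherits from the old one: argparse's type=bool treats any non-empty string as True, so --log_samples False still enables sample logging. If that matters, a store_true switch is the usual fix (shown here as a suggestion, not what the commit does):

import argparse

parser = argparse.ArgumentParser()
# bool("False") == True, so type=bool silently misparses an explicit "False";
# an on/off flag avoids the trap entirely.
parser.add_argument("--log_samples", action="store_true", help="Log inferenced samples per ckpt save steps")
parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")

args = parser.parse_args(["--log_samples", "--logger", "tensorboard"])
print(args.log_samples, args.logger)  # True tensorboard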
 
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -453,7 +453,7 @@ def start_training(
 
     cmd += f" --tokenizer {tokenizer_type} "
 
-    cmd += f" --export_samples True --logger {logger} "
+    cmd += f" --log_samples True --logger {logger} "
 
     print(cmd)
 
@@ -1321,18 +1321,14 @@ def get_combined_stats():
 
 
 def get_audio_select(file_sample):
-    select_audio_org = file_sample
+    select_audio_ref = file_sample
     select_audio_gen = file_sample
-    select_image_org = file_sample
-    select_image_gen = file_sample
 
     if file_sample is not None:
-        select_audio_org += "_org.wav"
+        select_audio_ref += "_ref.wav"
         select_audio_gen += "_gen.wav"
-        select_image_org += "_org.png"
-        select_image_gen += "_gen.png"
 
-    return select_audio_org, select_audio_gen, select_image_org, select_image_gen
+    return select_audio_ref, select_audio_gen
 
 
 with gr.Blocks() as app:
@@ -1515,7 +1511,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
 
         with gr.Row():
             mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
-            cd_logger = gr.Radio(label="logger", choices=["none", "wandb", "tensorboard"], value="wandb")
+            cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
             start_button = gr.Button("Start Training")
             stop_button = gr.Button("Stop Training", interactive=False)
 
@@ -1562,16 +1558,12 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
 
         list_audios, select_audio = get_audio_project(projects_selelect, False)
 
-        select_audio_org = select_audio
+        select_audio_ref = select_audio
         select_audio_gen = select_audio
-        select_image_org = select_audio
-        select_image_gen = select_audio
 
         if select_audio is not None:
-            select_audio_org += "_org.wav"
+            select_audio_ref += "_ref.wav"
             select_audio_gen += "_gen.wav"
-            select_image_org += "_org.png"
-            select_image_gen += "_gen.png"
 
         with gr.Row():
             ch_list_audio = gr.Dropdown(
@@ -1587,17 +1579,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
         cm_project.change(fn=get_audio_project, inputs=[cm_project], outputs=[ch_list_audio])
 
         with gr.Row():
-            audio_org_stream = gr.Audio(label="original", type="filepath", value=select_audio_org)
-            mel_org_stream = gr.Image(label="original", type="filepath", value=select_image_org)
-
-        with gr.Row():
+            audio_ref_stream = gr.Audio(label="original", type="filepath", value=select_audio_ref)
             audio_gen_stream = gr.Audio(label="generate", type="filepath", value=select_audio_gen)
-            mel_gen_stream = gr.Image(label="generate", type="filepath", value=select_image_gen)
 
         ch_list_audio.change(
             fn=get_audio_select,
             inputs=[ch_list_audio],
-            outputs=[audio_org_stream, audio_gen_stream, mel_org_stream, mel_gen_stream],
+            outputs=[audio_ref_stream, audio_gen_stream],
         )
 
         start_button.click(
 
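After this change the sample browser pairs audio only: each dropdown entry is a step_<N> path prefix that get_audio_select expands into the _ref.wav / _gen.wav files the trainer wrote. A quick sanity check of the updated helper, assuming the module imports cleanly in your environment (the path is illustrative):

from f5_tts.train.finetune_gradio import get_audio_select

ref, gen = get_audio_select("ckpts/my_project/samples/step_2000")
print(ref)  # ckpts/my_project/samples/step_2000_ref.wav
print(gen)  # ckpts/my_project/samples/step_2000_gen.wav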