SWivid committed
Commit 7b27962 · 1 Parent(s): 957b832

0.4.0 fix gradient accumulation; change checkpointing logic to per_updates
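For context on the rename, here is a minimal sketch (illustrative values, not code from this commit) of the step-versus-update bookkeeping the new naming reflects once gradient accumulation is enabled.

```python
# Illustrative sketch (not part of this commit): with gradient accumulation,
# one optimizer update spans grad_accumulation_steps dataloader steps, so all
# counters named *_per_updates now count optimizer updates, not batches.
grad_accumulation_steps = 4      # micro-batches accumulated per optimizer update
dataloader_steps = 200_000       # batches drawn from the dataloader

optimizer_updates = dataloader_steps // grad_accumulation_steps
print(optimizer_updates)         # 50000 updates: what save_per_updates etc. now count
```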

pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "f5-tts"
7
- version = "0.3.4"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
 
4
 
5
  [project]
6
  name = "f5-tts"
7
+ version = "0.4.0"
8
  description = "F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching"
9
  readme = "README.md"
10
  license = {text = "MIT License"}
src/f5_tts/configs/E2TTS_Base_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
- num_warmup_updates: 20000 # warmup steps
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -39,6 +39,6 @@ model:
39
 
40
  ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
- save_per_updates: 50000 # save checkpoint per steps
43
- last_per_steps: 5000 # save last checkpoint per steps
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
+ num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
 
39
 
40
  ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
+ save_per_updates: 50000 # save checkpoint per updates
43
+ last_per_updates: 5000 # save last checkpoint per updates
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
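A hedged reading of the renamed ckpts keys (values are the config defaults above; the loop below only mimics the save cadence and is not the project's trainer):

```python
# Sketch of the checkpoint cadence implied by the renamed keys above.
save_per_updates = 50000   # numbered checkpoint every 50k optimizer updates
last_per_updates = 5000    # rolling "last" checkpoint every 5k optimizer updates

written = []
for update in range(1, 100_001):
    if update % save_per_updates == 0:
        written.append(f"model_{update}.pt")   # kept snapshot
    if update % last_per_updates == 0:
        written.append("model_last.pt")        # overwritten in place
print(written[-3:])
```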
src/f5_tts/configs/E2TTS_Small_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
- num_warmup_updates: 20000 # warmup steps
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0
18
  bnb_optimizer: False
@@ -39,6 +39,6 @@ model:
39
 
40
  ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
- save_per_updates: 50000 # save checkpoint per steps
43
- last_per_steps: 5000 # save last checkpoint per steps
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
+ num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0
18
  bnb_optimizer: False
 
39
 
40
  ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
+ save_per_updates: 50000 # save checkpoint per updates
43
+ last_per_updates: 5000 # save last checkpoint per updates
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
- num_warmup_updates: 20000 # warmup steps
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
42
 
43
  ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
- save_per_updates: 50000 # save checkpoint per steps
46
- last_per_steps: 5000 # save last checkpoint per steps
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
+ num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
 
42
 
43
  ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
+ save_per_updates: 50000 # save checkpoint per updates
46
+ last_per_updates: 5000 # save last checkpoint per updates
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -12,7 +12,7 @@ datasets:
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
- num_warmup_updates: 20000 # warmup steps
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
@@ -42,6 +42,6 @@ model:
42
 
43
  ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
- save_per_updates: 50000 # save checkpoint per steps
46
- last_per_steps: 5000 # save last checkpoint per steps
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
12
  optim:
13
  epochs: 15
14
  learning_rate: 7.5e-5
15
+ num_warmup_updates: 20000 # warmup updates
16
  grad_accumulation_steps: 1 # note: updates = steps / grad_accumulation_steps
17
  max_grad_norm: 1.0 # gradient clipping
18
  bnb_optimizer: False # use bnb 8bit AdamW optimizer or not
 
42
 
43
  ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
+ save_per_updates: 50000 # save checkpoint per updates
46
+ last_per_updates: 5000 # save last checkpoint per updates
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/model/trainer.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import gc
 
4
  import os
5
 
6
  import torch
@@ -42,7 +43,7 @@ class Trainer:
42
  wandb_run_name="test_run",
43
  wandb_resume_id: str = None,
44
  log_samples: bool = False,
45
- last_per_steps=None,
46
  accelerate_kwargs: dict = dict(),
47
  ema_kwargs: dict = dict(),
48
  bnb_optimizer: bool = False,
@@ -57,6 +58,11 @@ class Trainer:
57
  print(f"Using logger: {logger}")
58
  self.log_samples = log_samples
59
 
 
 
 
 
 
60
  self.accelerator = Accelerator(
61
  log_with=logger if logger == "wandb" else None,
62
  kwargs_handlers=[ddp_kwargs],
@@ -102,7 +108,7 @@ class Trainer:
102
  self.epochs = epochs
103
  self.num_warmup_updates = num_warmup_updates
104
  self.save_per_updates = save_per_updates
105
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
106
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
107
 
108
  self.batch_size = batch_size
@@ -132,7 +138,7 @@ class Trainer:
132
  def is_main(self):
133
  return self.accelerator.is_main_process
134
 
135
- def save_checkpoint(self, step, last=False):
136
  self.accelerator.wait_for_everyone()
137
  if self.is_main:
138
  checkpoint = dict(
@@ -140,15 +146,15 @@ class Trainer:
140
  optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
141
  ema_model_state_dict=self.ema_model.state_dict(),
142
  scheduler_state_dict=self.scheduler.state_dict(),
143
- step=step,
144
  )
145
  if not os.path.exists(self.checkpoint_path):
146
  os.makedirs(self.checkpoint_path)
147
  if last:
148
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
149
- print(f"Saved last checkpoint at step {step}")
150
  else:
151
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
152
 
153
  def load_checkpoint(self):
154
  if (
@@ -177,7 +183,14 @@ class Trainer:
177
  if self.is_main:
178
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
179
 
180
- if "step" in checkpoint:
 
 
 
 
 
 
 
181
  # patch for backward compatibility, 305e3ea
182
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
183
  if key in checkpoint["model_state_dict"]:
@@ -187,19 +200,19 @@ class Trainer:
187
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
188
  if self.scheduler:
189
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
190
- step = checkpoint["step"]
191
  else:
192
  checkpoint["model_state_dict"] = {
193
  k.replace("ema_model.", ""): v
194
  for k, v in checkpoint["ema_model_state_dict"].items()
195
- if k not in ["initted", "step"]
196
  }
197
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
198
- step = 0
199
 
200
  del checkpoint
201
  gc.collect()
202
- return step
203
 
204
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
205
  if self.log_samples:
@@ -248,25 +261,26 @@ class Trainer:
248
 
249
  # accelerator.prepare() dispatches batches to devices;
250
  # which means the length of dataloader calculated before, should consider the number of devices
251
- warmup_steps = (
252
  self.num_warmup_updates * self.accelerator.num_processes
253
  ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
254
  # otherwise by default with split_batches=False, warmup steps change with num_processes
255
- total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
256
- decay_steps = total_steps - warmup_steps
257
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
258
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
259
  self.scheduler = SequentialLR(
260
- self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
261
  )
262
  train_dataloader, self.scheduler = self.accelerator.prepare(
263
  train_dataloader, self.scheduler
264
- ) # actual steps = 1 gpu steps / gpus
265
- start_step = self.load_checkpoint()
266
- global_step = start_step
267
 
268
  if exists(resumable_with_seed):
269
  orig_epoch_step = len(train_dataloader)
 
270
  skipped_epoch = int(start_step // orig_epoch_step)
271
  skipped_batch = start_step % orig_epoch_step
272
  skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
@@ -276,23 +290,21 @@ class Trainer:
276
  for epoch in range(skipped_epoch, self.epochs):
277
  self.model.train()
278
  if exists(resumable_with_seed) and epoch == skipped_epoch:
279
- progress_bar = tqdm(
280
- skipped_dataloader,
281
- desc=f"Epoch {epoch+1}/{self.epochs}",
282
- unit="step",
283
- disable=not self.accelerator.is_local_main_process,
284
- initial=skipped_batch,
285
- total=orig_epoch_step,
286
- )
287
  else:
288
- progress_bar = tqdm(
289
- train_dataloader,
290
- desc=f"Epoch {epoch+1}/{self.epochs}",
291
- unit="step",
292
- disable=not self.accelerator.is_local_main_process,
293
- )
294
-
295
- for batch in progress_bar:
 
 
 
 
296
  with self.accelerator.accumulate(self.model):
297
  text_inputs = batch["text"]
298
  mel_spec = batch["mel"].permute(0, 2, 1)
@@ -301,7 +313,7 @@ class Trainer:
301
  # TODO. add duration predictor training
302
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
303
  dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
304
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
305
 
306
  loss, cond, pred = self.model(
307
  mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
@@ -318,18 +330,20 @@ class Trainer:
318
  if self.is_main and self.accelerator.sync_gradients:
319
  self.ema_model.update()
320
 
321
- global_step += 1
 
 
322
 
323
  if self.accelerator.is_local_main_process:
324
- self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
 
 
325
  if self.logger == "tensorboard":
326
- self.writer.add_scalar("loss", loss.item(), global_step)
327
- self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_step)
328
-
329
- progress_bar.set_postfix(step=str(global_step), loss=loss.item())
330
 
331
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
332
- self.save_checkpoint(global_step)
333
 
334
  if self.log_samples and self.accelerator.is_local_main_process:
335
  ref_audio_len = mel_lengths[0]
@@ -355,12 +369,16 @@ class Trainer:
355
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
356
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
357
 
358
- torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
359
- torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
 
 
 
360
 
361
- if global_step % self.last_per_steps == 0:
362
- self.save_checkpoint(global_step, last=True)
363
 
364
- self.save_checkpoint(global_step, last=True)
365
 
366
  self.accelerator.end_training()
 
1
  from __future__ import annotations
2
 
3
  import gc
4
+ import math
5
  import os
6
 
7
  import torch
 
43
  wandb_run_name="test_run",
44
  wandb_resume_id: str = None,
45
  log_samples: bool = False,
46
+ last_per_updates=None,
47
  accelerate_kwargs: dict = dict(),
48
  ema_kwargs: dict = dict(),
49
  bnb_optimizer: bool = False,
 
58
  print(f"Using logger: {logger}")
59
  self.log_samples = log_samples
60
 
61
+ if grad_accumulation_steps > 1 and self.is_main:
62
+ print(
63
+ "Gradient accumulation detected: checkpointing now counts optimizer updates (per_updates); the old per_steps logic applies only to runs from before f992c4e"
64
+ )
65
+
66
  self.accelerator = Accelerator(
67
  log_with=logger if logger == "wandb" else None,
68
  kwargs_handlers=[ddp_kwargs],
 
108
  self.epochs = epochs
109
  self.num_warmup_updates = num_warmup_updates
110
  self.save_per_updates = save_per_updates
111
+ self.last_per_updates = default(last_per_updates, save_per_updates)
112
  self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")
113
 
114
  self.batch_size = batch_size
 
138
  def is_main(self):
139
  return self.accelerator.is_main_process
140
 
141
+ def save_checkpoint(self, update, last=False):
142
  self.accelerator.wait_for_everyone()
143
  if self.is_main:
144
  checkpoint = dict(
 
146
  optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
147
  ema_model_state_dict=self.ema_model.state_dict(),
148
  scheduler_state_dict=self.scheduler.state_dict(),
149
+ update=update,
150
  )
151
  if not os.path.exists(self.checkpoint_path):
152
  os.makedirs(self.checkpoint_path)
153
  if last:
154
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
155
+ print(f"Saved last checkpoint at update {update}")
156
  else:
157
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
158
 
159
  def load_checkpoint(self):
160
  if (
 
183
  if self.is_main:
184
  self.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
185
 
186
+ if "update" in checkpoint or "step" in checkpoint:
187
+ # patch for backward compatibility with checkpoints saved before f992c4e
188
+ if "step" in checkpoint:
189
+ checkpoint["update"] = checkpoint["step"] // self.grad_accumulation_steps
190
+ if self.grad_accumulation_steps > 1 and self.is_main:
191
+ print(
192
+ "F5-TTS WARNING: Loading a checkpoint saved with the per_steps logic (before f992c4e); converting to per_updates using the current grad_accumulation_steps setting, which may cause unexpected behaviour."
193
+ )
194
  # patch for backward compatibility, 305e3ea
195
  for key in ["mel_spec.mel_stft.mel_scale.fb", "mel_spec.mel_stft.spectrogram.window"]:
196
  if key in checkpoint["model_state_dict"]:
 
200
  self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint["optimizer_state_dict"])
201
  if self.scheduler:
202
  self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
203
+ update = checkpoint["update"]
204
  else:
205
  checkpoint["model_state_dict"] = {
206
  k.replace("ema_model.", ""): v
207
  for k, v in checkpoint["ema_model_state_dict"].items()
208
+ if k not in ["initted", "update", "step"]
209
  }
210
  self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint["model_state_dict"])
211
+ update = 0
212
 
213
  del checkpoint
214
  gc.collect()
215
+ return update
216
 
217
  def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
218
  if self.log_samples:
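To make the backward-compatibility branch in load_checkpoint concrete, a stand-alone sketch of the step-to-update conversion applied to checkpoints written before f992c4e (the dict is a stand-in, not a real checkpoint):

```python
# Stand-alone sketch of the legacy-counter conversion done in load_checkpoint.
grad_accumulation_steps = 2

checkpoint = {"step": 100_000}   # counter saved by the old per_steps logic
if "update" not in checkpoint and "step" in checkpoint:
    # Same arithmetic as above: steps are divided by the accumulation factor.
    checkpoint["update"] = checkpoint["step"] // grad_accumulation_steps

print(checkpoint["update"])      # 50000 optimizer updates
```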
 
261
 
262
  # accelerator.prepare() dispatches batches to devices;
263
  # which means the length of dataloader calculated before, should consider the number of devices
264
+ warmup_updates = (
265
  self.num_warmup_updates * self.accelerator.num_processes
266
  ) # consider a fixed warmup steps while using accelerate multi-gpu ddp
267
  # otherwise by default with split_batches=False, warmup steps change with num_processes
268
+ total_updates = math.ceil(len(train_dataloader) / self.grad_accumulation_steps) * self.epochs
269
+ decay_updates = total_updates - warmup_updates
270
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
271
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
272
  self.scheduler = SequentialLR(
273
+ self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_updates]
274
  )
275
  train_dataloader, self.scheduler = self.accelerator.prepare(
276
  train_dataloader, self.scheduler
277
+ ) # actual multi-GPU updates = single-GPU updates / number of GPUs
278
+ start_update = self.load_checkpoint()
279
+ global_update = start_update
280
 
281
  if exists(resumable_with_seed):
282
  orig_epoch_step = len(train_dataloader)
283
+ start_step = start_update * self.grad_accumulation_steps
284
  skipped_epoch = int(start_step // orig_epoch_step)
285
  skipped_batch = start_step % orig_epoch_step
286
  skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
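The learning-rate schedule is now sized in updates. A minimal runnable sketch with stock PyTorch schedulers (sizes are illustrative; the real trainer also multiplies the warmup by num_processes):

```python
# Minimal warmup + linear-decay schedule expressed in optimizer updates.
import torch
from torch.optim.lr_scheduler import LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)

warmup_updates = 100
total_updates = 1_000
decay_updates = total_updates - warmup_updates

warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_updates)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_updates)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_updates])

for _ in range(total_updates):
    optimizer.step()     # one optimizer update (after any gradient accumulation)
    scheduler.step()     # the scheduler now advances once per update
```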
 
290
  for epoch in range(skipped_epoch, self.epochs):
291
  self.model.train()
292
  if exists(resumable_with_seed) and epoch == skipped_epoch:
293
+ progress_bar_initial = math.ceil(skipped_batch / self.grad_accumulation_steps)
294
+ current_dataloader = skipped_dataloader
 
 
 
 
 
 
295
  else:
296
+ progress_bar_initial = 0
297
+ current_dataloader = train_dataloader
298
+
299
+ progress_bar = tqdm(
300
+ range(math.ceil(len(train_dataloader) / self.grad_accumulation_steps)),
301
+ desc=f"Epoch {epoch+1}/{self.epochs}",
302
+ unit="update",
303
+ disable=not self.accelerator.is_local_main_process,
304
+ initial=progress_bar_initial,
305
+ )
306
+
307
+ for batch in current_dataloader:
308
  with self.accelerator.accumulate(self.model):
309
  text_inputs = batch["text"]
310
  mel_spec = batch["mel"].permute(0, 2, 1)
 
313
  # TODO. add duration predictor training
314
  if self.duration_predictor is not None and self.accelerator.is_local_main_process:
315
  dur_loss = self.duration_predictor(mel_spec, lens=batch.get("durations"))
316
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_update)
317
 
318
  loss, cond, pred = self.model(
319
  mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler
 
330
  if self.is_main and self.accelerator.sync_gradients:
331
  self.ema_model.update()
332
 
333
+ global_update += 1
334
+ progress_bar.update(1)
335
+ progress_bar.set_postfix(update=str(global_update), loss=loss.item())
336
 
337
  if self.accelerator.is_local_main_process:
338
+ self.accelerator.log(
339
+ {"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_update
340
+ )
341
  if self.logger == "tensorboard":
342
+ self.writer.add_scalar("loss", loss.item(), global_update)
343
+ self.writer.add_scalar("lr", self.scheduler.get_last_lr()[0], global_update)
 
 
344
 
345
+ if global_update % self.save_per_updates == 0:
346
+ self.save_checkpoint(global_update)
347
 
348
  if self.log_samples and self.accelerator.is_local_main_process:
349
  ref_audio_len = mel_lengths[0]
 
369
  gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
370
  ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
371
 
372
+ torchaudio.save(
373
+ f"{log_samples_path}/update_{global_update}_gen.wav", gen_audio, target_sample_rate
374
+ )
375
+ torchaudio.save(
376
+ f"{log_samples_path}/update_{global_update}_ref.wav", ref_audio, target_sample_rate
377
+ )
378
 
379
+ if global_update % self.last_per_updates == 0:
380
+ self.save_checkpoint(global_update, last=True)
381
 
382
+ self.save_checkpoint(global_update, last=True)
383
 
384
  self.accelerator.end_training()
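A hedged recap of the resumption arithmetic used in train() above: the stored update count is converted back into dataloader steps to know how many batches to skip, and the progress bar is re-expressed in updates. Numbers are illustrative.

```python
# Recap of the resume bookkeeping in train(); values are examples only.
import math

grad_accumulation_steps = 2
batches_per_epoch = 1_000        # len(train_dataloader) on this process
start_update = 1_200             # value returned by load_checkpoint()

start_step = start_update * grad_accumulation_steps          # 2400 dataloader steps done
skipped_epoch = start_step // batches_per_epoch              # 2 full epochs completed
skipped_batch = start_step % batches_per_epoch               # 400 batches into the next epoch
progress_bar_initial = math.ceil(skipped_batch / grad_accumulation_steps)  # 200 updates

print(skipped_epoch, skipped_batch, progress_bar_initial)    # 2 400 200
```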
src/f5_tts/scripts/count_max_epoch.py CHANGED
@@ -20,13 +20,13 @@ grad_accum = 1
20
  mini_batch_frames = frames_per_gpu * grad_accum * gpus
21
  mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
22
  updates_per_epoch = total_hours / mini_batch_hours
23
- steps_per_epoch = updates_per_epoch * grad_accum
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
  print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
- print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
31
  # others
32
  print(f"total {total_hours:.0f} hours")
 
20
  mini_batch_frames = frames_per_gpu * grad_accum * gpus
21
  mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
22
  updates_per_epoch = total_hours / mini_batch_hours
23
+ # steps_per_epoch = updates_per_epoch * grad_accum
24
 
25
  # result
26
  epochs = wanted_max_updates / updates_per_epoch
27
  print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
28
  print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
29
+ # print(f" or approx. 0/{steps_per_epoch:.0f} steps")
30
 
31
  # others
32
  print(f"total {total_hours:.0f} hours")
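A worked example of the arithmetic this script performs, with assumed inputs (the frame budget, dataset hours, and update target below are placeholders, not the script's actual defaults):

```python
# Worked example of count_max_epoch.py's update-based budget; inputs are assumed.
frames_per_gpu = 38_400          # assumed per-GPU batch size in mel frames
grad_accum = 1
gpus = 8
mel_hop_length = 256
mel_sampling_rate = 24_000
total_hours = 95_000             # assumed dataset size in hours
wanted_max_updates = 1_000_000   # assumed training budget in optimizer updates

mini_batch_frames = frames_per_gpu * grad_accum * gpus
mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
updates_per_epoch = total_hours / mini_batch_hours
epochs = wanted_max_updates / updates_per_epoch
print(f"~{updates_per_epoch:.0f} updates per epoch -> set epochs to about {epochs:.0f}")
```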
src/f5_tts/train/finetune_cli.py CHANGED
@@ -27,7 +27,7 @@ def parse_args():
27
 
28
  # num_warmup_updates = 300 for 5000 sample about 10 hours
29
 
30
- # change save_per_updates , last_per_steps change this value what you need ,
31
 
32
  parser = argparse.ArgumentParser(description="Train CFM Model")
33
 
@@ -44,9 +44,9 @@ def parse_args():
44
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
45
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
46
  parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
47
- parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup steps")
48
- parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
49
- parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
50
  parser.add_argument("--finetune", action="store_true", help="Use Finetune")
51
  parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
52
  parser.add_argument(
@@ -61,7 +61,7 @@ def parse_args():
61
  parser.add_argument(
62
  "--log_samples",
63
  action="store_true",
64
- help="Log inferenced samples per ckpt save steps",
65
  )
66
  parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
67
  parser.add_argument(
@@ -156,7 +156,7 @@ def main():
156
  wandb_run_name=args.exp_name,
157
  wandb_resume_id=wandb_resume_id,
158
  log_samples=args.log_samples,
159
- last_per_steps=args.last_per_steps,
160
  bnb_optimizer=args.bnb_optimizer,
161
  )
162
 
 
27
 
28
  # num_warmup_updates = 300 for 5000 sample about 10 hours
29
 
30
+ # adjust save_per_updates and last_per_updates to the values you need
31
 
32
  parser = argparse.ArgumentParser(description="Train CFM Model")
33
 
 
44
  parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
45
  parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
46
  parser.add_argument("--epochs", type=int, default=100, help="Number of training epochs")
47
+ parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
48
+ parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
49
+ parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")
50
  parser.add_argument("--finetune", action="store_true", help="Use Finetune")
51
  parser.add_argument("--pretrain", type=str, default=None, help="the path to the checkpoint")
52
  parser.add_argument(
 
61
  parser.add_argument(
62
  "--log_samples",
63
  action="store_true",
64
+ help="Log inference samples at each checkpoint-saving update",
65
  )
66
  parser.add_argument("--logger", type=str, default=None, choices=["wandb", "tensorboard"], help="logger")
67
  parser.add_argument(
 
156
  wandb_run_name=args.exp_name,
157
  wandb_resume_id=wandb_resume_id,
158
  log_samples=args.log_samples,
159
+ last_per_updates=args.last_per_updates,
160
  bnb_optimizer=args.bnb_optimizer,
161
  )
162
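A toy reproduction of the renamed fine-tuning flags (trimmed to the three arguments touched here; defaults copied from the parser above) showing how they parse after this commit:

```python
# Trimmed argparse sketch of the renamed flags in finetune_cli.py.
import argparse

parser = argparse.ArgumentParser(description="Train CFM Model (excerpt)")
parser.add_argument("--num_warmup_updates", type=int, default=300, help="Warmup updates")
parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X updates")
parser.add_argument("--last_per_updates", type=int, default=50000, help="Save last checkpoint every X updates")

args = parser.parse_args(["--save_per_updates", "10000", "--last_per_updates", "5000"])
print(args.save_per_updates, args.last_per_updates)   # 10000 5000
```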
 
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -62,7 +62,7 @@ def save_settings(
62
  epochs,
63
  num_warmup_updates,
64
  save_per_updates,
65
- last_per_steps,
66
  finetune,
67
  file_checkpoint_train,
68
  tokenizer_type,
@@ -86,7 +86,7 @@ def save_settings(
86
  "epochs": epochs,
87
  "num_warmup_updates": num_warmup_updates,
88
  "save_per_updates": save_per_updates,
89
- "last_per_steps": last_per_steps,
90
  "finetune": finetune,
91
  "file_checkpoint_train": file_checkpoint_train,
92
  "tokenizer_type": tokenizer_type,
@@ -118,7 +118,7 @@ def load_settings(project_name):
118
  "epochs": 100,
119
  "num_warmup_updates": 2,
120
  "save_per_updates": 300,
121
- "last_per_steps": 100,
122
  "finetune": True,
123
  "file_checkpoint_train": "",
124
  "tokenizer_type": "pinyin",
@@ -138,7 +138,7 @@ def load_settings(project_name):
138
  settings["epochs"],
139
  settings["num_warmup_updates"],
140
  settings["save_per_updates"],
141
- settings["last_per_steps"],
142
  settings["finetune"],
143
  settings["file_checkpoint_train"],
144
  settings["tokenizer_type"],
@@ -154,6 +154,8 @@ def load_settings(project_name):
154
  settings["logger"] = "wandb"
155
  if "bnb_optimizer" not in settings:
156
  settings["bnb_optimizer"] = False
 
 
157
  return (
158
  settings["exp_name"],
159
  settings["learning_rate"],
@@ -165,7 +167,7 @@ def load_settings(project_name):
165
  settings["epochs"],
166
  settings["num_warmup_updates"],
167
  settings["save_per_updates"],
168
- settings["last_per_steps"],
169
  settings["finetune"],
170
  settings["file_checkpoint_train"],
171
  settings["tokenizer_type"],
@@ -379,7 +381,7 @@ def start_training(
379
  epochs=11,
380
  num_warmup_updates=200,
381
  save_per_updates=400,
382
- last_per_steps=800,
383
  finetune=True,
384
  file_checkpoint_train="",
385
  tokenizer_type="pinyin",
@@ -448,7 +450,7 @@ def start_training(
448
  f"--epochs {epochs} "
449
  f"--num_warmup_updates {num_warmup_updates} "
450
  f"--save_per_updates {save_per_updates} "
451
- f"--last_per_steps {last_per_steps} "
452
  f"--dataset_name {dataset_name}"
453
  )
454
 
@@ -482,7 +484,7 @@ def start_training(
482
  epochs,
483
  num_warmup_updates,
484
  save_per_updates,
485
- last_per_steps,
486
  finetune,
487
  file_checkpoint_train,
488
  tokenizer_type,
@@ -880,7 +882,7 @@ def calculate_train(
880
  learning_rate,
881
  num_warmup_updates,
882
  save_per_updates,
883
- last_per_steps,
884
  finetune,
885
  ):
886
  path_project = os.path.join(path_data, name_project)
@@ -892,7 +894,7 @@ def calculate_train(
892
  max_samples,
893
  num_warmup_updates,
894
  save_per_updates,
895
- last_per_steps,
896
  "project not found !",
897
  learning_rate,
898
  )
@@ -940,14 +942,14 @@ def calculate_train(
940
 
941
  num_warmup_updates = int(samples * 0.05)
942
  save_per_updates = int(samples * 0.10)
943
- last_per_steps = int(save_per_updates * 0.25)
944
 
945
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
946
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
947
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
948
- last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
949
- if last_per_steps <= 0:
950
- last_per_steps = 2
951
 
952
  total_hours = hours
953
  mel_hop_length = 256
@@ -978,7 +980,7 @@ def calculate_train(
978
  max_samples,
979
  num_warmup_updates,
980
  save_per_updates,
981
- last_per_steps,
982
  samples,
983
  learning_rate,
984
  int(epochs),
@@ -1530,7 +1532,7 @@ Skip this step if you have your dataset, raw.arrow, duration.json, and vocab.txt
1530
 
1531
  with gr.TabItem("Train Data"):
1532
  gr.Markdown("""```plaintext
1533
- The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per steps are set correctly, or change them manually as needed.
1534
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1535
  ```""")
1536
  with gr.Row():
@@ -1561,7 +1563,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1561
 
1562
  with gr.Row():
1563
  save_per_updates = gr.Number(label="Save per Updates", value=300)
1564
- last_per_steps = gr.Number(label="Last per Steps", value=100)
1565
 
1566
  with gr.Row():
1567
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1582,7 +1584,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1582
  epochsv,
1583
  num_warmupv_updatesv,
1584
  save_per_updatesv,
1585
- last_per_stepsv,
1586
  finetunev,
1587
  file_checkpoint_trainv,
1588
  tokenizer_typev,
@@ -1601,7 +1603,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1601
  epochs.value = epochsv
1602
  num_warmup_updates.value = num_warmupv_updatesv
1603
  save_per_updates.value = save_per_updatesv
1604
- last_per_steps.value = last_per_stepsv
1605
  ch_finetune.value = finetunev
1606
  file_checkpoint_train.value = file_checkpoint_trainv
1607
  tokenizer_type.value = tokenizer_typev
@@ -1659,7 +1661,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1659
  epochs,
1660
  num_warmup_updates,
1661
  save_per_updates,
1662
- last_per_steps,
1663
  ch_finetune,
1664
  file_checkpoint_train,
1665
  tokenizer_type,
@@ -1682,7 +1684,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1682
  learning_rate,
1683
  num_warmup_updates,
1684
  save_per_updates,
1685
- last_per_steps,
1686
  ch_finetune,
1687
  ],
1688
  outputs=[
@@ -1690,7 +1692,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1690
  max_samples,
1691
  num_warmup_updates,
1692
  save_per_updates,
1693
- last_per_steps,
1694
  lb_samples,
1695
  learning_rate,
1696
  epochs,
@@ -1713,7 +1715,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1713
  epochs,
1714
  num_warmup_updates,
1715
  save_per_updates,
1716
- last_per_steps,
1717
  ch_finetune,
1718
  file_checkpoint_train,
1719
  tokenizer_type,
 
62
  epochs,
63
  num_warmup_updates,
64
  save_per_updates,
65
+ last_per_updates,
66
  finetune,
67
  file_checkpoint_train,
68
  tokenizer_type,
 
86
  "epochs": epochs,
87
  "num_warmup_updates": num_warmup_updates,
88
  "save_per_updates": save_per_updates,
89
+ "last_per_updates": last_per_updates,
90
  "finetune": finetune,
91
  "file_checkpoint_train": file_checkpoint_train,
92
  "tokenizer_type": tokenizer_type,
 
118
  "epochs": 100,
119
  "num_warmup_updates": 2,
120
  "save_per_updates": 300,
121
+ "last_per_updates": 100,
122
  "finetune": True,
123
  "file_checkpoint_train": "",
124
  "tokenizer_type": "pinyin",
 
138
  settings["epochs"],
139
  settings["num_warmup_updates"],
140
  settings["save_per_updates"],
141
+ settings["last_per_updates"],
142
  settings["finetune"],
143
  settings["file_checkpoint_train"],
144
  settings["tokenizer_type"],
 
154
  settings["logger"] = "wandb"
155
  if "bnb_optimizer" not in settings:
156
  settings["bnb_optimizer"] = False
157
+ if "last_per_updates" not in settings: # patch for backward compatibility with settings saved before f992c4e
158
+ settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
159
  return (
160
  settings["exp_name"],
161
  settings["learning_rate"],
 
167
  settings["epochs"],
168
  settings["num_warmup_updates"],
169
  settings["save_per_updates"],
170
+ settings["last_per_updates"],
171
  settings["finetune"],
172
  settings["file_checkpoint_train"],
173
  settings["tokenizer_type"],
 
381
  epochs=11,
382
  num_warmup_updates=200,
383
  save_per_updates=400,
384
+ last_per_updates=800,
385
  finetune=True,
386
  file_checkpoint_train="",
387
  tokenizer_type="pinyin",
 
450
  f"--epochs {epochs} "
451
  f"--num_warmup_updates {num_warmup_updates} "
452
  f"--save_per_updates {save_per_updates} "
453
+ f"--last_per_updates {last_per_updates} "
454
  f"--dataset_name {dataset_name}"
455
  )
456
 
 
484
  epochs,
485
  num_warmup_updates,
486
  save_per_updates,
487
+ last_per_updates,
488
  finetune,
489
  file_checkpoint_train,
490
  tokenizer_type,
 
882
  learning_rate,
883
  num_warmup_updates,
884
  save_per_updates,
885
+ last_per_updates,
886
  finetune,
887
  ):
888
  path_project = os.path.join(path_data, name_project)
 
894
  max_samples,
895
  num_warmup_updates,
896
  save_per_updates,
897
+ last_per_updates,
898
  "project not found !",
899
  learning_rate,
900
  )
 
942
 
943
  num_warmup_updates = int(samples * 0.05)
944
  save_per_updates = int(samples * 0.10)
945
+ last_per_updates = int(save_per_updates * 0.25)
946
 
947
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
948
  num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
949
  save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
950
+ last_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_updates)
951
+ if last_per_updates <= 0:
952
+ last_per_updates = 2
953
 
954
  total_hours = hours
955
  mel_hop_length = 256
 
980
  max_samples,
981
  num_warmup_updates,
982
  save_per_updates,
983
+ last_per_updates,
984
  samples,
985
  learning_rate,
986
  int(epochs),
 
1532
 
1533
  with gr.TabItem("Train Data"):
1534
  gr.Markdown("""```plaintext
1535
+ The auto-setting is still experimental. Please make sure that the epochs, save per updates, and last per updates are set correctly, or change them manually as needed.
1536
  If you encounter a memory error, try reducing the batch size per GPU to a smaller number.
1537
  ```""")
1538
  with gr.Row():
 
1563
 
1564
  with gr.Row():
1565
  save_per_updates = gr.Number(label="Save per Updates", value=300)
1566
+ last_per_updates = gr.Number(label="Last per Updates", value=100)
1567
 
1568
  with gr.Row():
1569
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
 
1584
  epochsv,
1585
  num_warmupv_updatesv,
1586
  save_per_updatesv,
1587
+ last_per_updatesv,
1588
  finetunev,
1589
  file_checkpoint_trainv,
1590
  tokenizer_typev,
 
1603
  epochs.value = epochsv
1604
  num_warmup_updates.value = num_warmupv_updatesv
1605
  save_per_updates.value = save_per_updatesv
1606
+ last_per_updates.value = last_per_updatesv
1607
  ch_finetune.value = finetunev
1608
  file_checkpoint_train.value = file_checkpoint_trainv
1609
  tokenizer_type.value = tokenizer_typev
 
1661
  epochs,
1662
  num_warmup_updates,
1663
  save_per_updates,
1664
+ last_per_updates,
1665
  ch_finetune,
1666
  file_checkpoint_train,
1667
  tokenizer_type,
 
1684
  learning_rate,
1685
  num_warmup_updates,
1686
  save_per_updates,
1687
+ last_per_updates,
1688
  ch_finetune,
1689
  ],
1690
  outputs=[
 
1692
  max_samples,
1693
  num_warmup_updates,
1694
  save_per_updates,
1695
+ last_per_updates,
1696
  lb_samples,
1697
  learning_rate,
1698
  epochs,
 
1715
  epochs,
1716
  num_warmup_updates,
1717
  save_per_updates,
1718
+ last_per_updates,
1719
  ch_finetune,
1720
  file_checkpoint_train,
1721
  tokenizer_type,
src/f5_tts/train/train.py CHANGED
@@ -55,7 +55,7 @@ def main(cfg):
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
- last_per_steps=cfg.ckpts.last_per_steps,
59
  log_samples=True,
60
  bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,
 
55
  wandb_project="CFM-TTS",
56
  wandb_run_name=exp_name,
57
  wandb_resume_id=wandb_resume_id,
58
+ last_per_updates=cfg.ckpts.last_per_updates,
59
  log_samples=True,
60
  bnb_optimizer=cfg.optim.bnb_optimizer,
61
  mel_spec_type=mel_spec_type,