Hasan Can Solakoğlu committed on
Commit
2d70bfe
·
1 Parent(s): b09897b

Keep Last N Checkpoints (#718)

Browse files

* Add checkpoint management feature

- Introduced `keep_last_n_checkpoints` parameter in configuration and training scripts to manage the number of recent checkpoints retained.
- Updated `finetune_cli.py`, `finetune_gradio.py`, and `trainer.py` to support this new parameter.
- Implemented logic to remove older checkpoints beyond the specified limit during training.
- Adjusted settings loading and saving to include the new checkpoint management option.

This enhancement improves the training process by preventing excessive storage usage from old checkpoints.

src/f5_tts/configs/E2TTS_Base_train.yaml CHANGED
@@ -41,4 +41,5 @@ ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per updates
43
  last_per_updates: 5000 # save last checkpoint per updates
 
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per updates
43
  last_per_updates: 5000 # save last checkpoint per updates
44
+ keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
45
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small_train.yaml CHANGED
@@ -41,4 +41,5 @@ ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per updates
43
  last_per_updates: 5000 # save last checkpoint per updates
 
44
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per updates
43
  last_per_updates: 5000 # save last checkpoint per updates
44
+ keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
45
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -44,4 +44,5 @@ ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
  save_per_updates: 50000 # save checkpoint per updates
46
  last_per_updates: 5000 # save last checkpoint per updates
 
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
44
  logger: wandb # wandb | tensorboard | None
45
  save_per_updates: 50000 # save checkpoint per updates
46
  last_per_updates: 5000 # save last checkpoint per updates
47
+ keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
48
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -44,4 +44,5 @@ ckpts:
44
  logger: wandb # wandb | tensorboard | None
45
  save_per_updates: 50000 # save checkpoint per updates
46
  last_per_updates: 5000 # save last checkpoint per updates
 
47
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
 
44
  logger: wandb # wandb | tensorboard | None
45
  save_per_updates: 50000 # save checkpoint per updates
46
  last_per_updates: 5000 # save last checkpoint per updates
47
+ keep_last_n_checkpoints: -1 # -1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints
48
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/model/trainer.py CHANGED
@@ -50,7 +50,17 @@ class Trainer:
50
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
51
  is_local_vocoder: bool = False, # use local path vocoder
52
  local_vocoder_path: str = "", # local vocoder path
 
 
53
  ):
 
 
 
 
 
 
 
 
54
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
55
 
56
  if logger == "wandb" and not wandb.api.api_key:
@@ -134,6 +144,8 @@ class Trainer:
134
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
135
  self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
136
 
 
 
137
  @property
138
  def is_main(self):
139
  return self.accelerator.is_main_process
@@ -154,7 +166,26 @@ class Trainer:
154
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
155
  print(f"Saved last checkpoint at update {update}")
156
  else:
 
 
 
 
157
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  def load_checkpoint(self):
160
  if (
 
50
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
51
  is_local_vocoder: bool = False, # use local path vocoder
52
  local_vocoder_path: str = "", # local vocoder path
53
+ keep_last_n_checkpoints: int
54
+ | None = -1, # -1 (default) to keep all, 0 to not save intermediate ckpts, positive N to keep last N checkpoints
55
  ):
56
+ # Validate keep_last_n_checkpoints
57
+ if not isinstance(keep_last_n_checkpoints, int):
58
+ raise ValueError("keep_last_n_checkpoints must be an integer")
59
+ if keep_last_n_checkpoints < -1:
60
+ raise ValueError(
61
+ "keep_last_n_checkpoints must be -1 (keep all), 0 (no intermediate checkpoints), or positive integer"
62
+ )
63
+
64
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
65
 
66
  if logger == "wandb" and not wandb.api.api_key:
 
144
  self.optimizer = AdamW(model.parameters(), lr=learning_rate)
145
  self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
146
 
147
+ self.keep_last_n_checkpoints = keep_last_n_checkpoints if keep_last_n_checkpoints is not None else None
148
+
149
  @property
150
  def is_main(self):
151
  return self.accelerator.is_main_process
 
166
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
167
  print(f"Saved last checkpoint at update {update}")
168
  else:
169
+ # Skip saving intermediate checkpoints if keep_last_n_checkpoints is 0
170
+ if self.keep_last_n_checkpoints == 0:
171
+ return
172
+
173
  self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
174
+ # Implement rolling checkpoint system - only if keep_last_n_checkpoints is positive
175
+ if self.keep_last_n_checkpoints > 0:
176
+ # Get all checkpoint files except model_last.pt
177
+ checkpoints = [
178
+ f
179
+ for f in os.listdir(self.checkpoint_path)
180
+ if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
181
+ ]
182
+ # Sort by step number
183
+ checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
184
+ # Remove old checkpoints if we have more than keep_last_n_checkpoints
185
+ while len(checkpoints) > self.keep_last_n_checkpoints:
186
+ oldest_checkpoint = checkpoints.pop(0)
187
+ os.remove(os.path.join(self.checkpoint_path, oldest_checkpoint))
188
+ print(f"Removed old checkpoint: {oldest_checkpoint}")
189
 
190
  def load_checkpoint(self):
191
  if (
src/f5_tts/train/finetune_cli.py CHANGED
@@ -69,6 +69,12 @@ def parse_args():
69
  action="store_true",
70
  help="Use 8-bit Adam optimizer from bitsandbytes",
71
  )
 
 
 
 
 
 
72
 
73
  return parser.parse_args()
74
 
@@ -158,6 +164,7 @@ def main():
158
  log_samples=args.log_samples,
159
  last_per_updates=args.last_per_updates,
160
  bnb_optimizer=args.bnb_optimizer,
 
161
  )
162
 
163
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
 
69
  action="store_true",
70
  help="Use 8-bit Adam optimizer from bitsandbytes",
71
  )
72
+ parser.add_argument(
73
+ "--keep_last_n_checkpoints",
74
+ type=int,
75
+ default=-1,
76
+ help="-1 (default) to keep all checkpoints, 0 to not save intermediate checkpoints, positive N to keep last N checkpoints",
77
+ )
78
 
79
  return parser.parse_args()
80
 
 
164
  log_samples=args.log_samples,
165
  last_per_updates=args.last_per_updates,
166
  bnb_optimizer=args.bnb_optimizer,
167
+ keep_last_n_checkpoints=args.keep_last_n_checkpoints,
168
  )
169
 
170
  train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -70,6 +70,7 @@ def save_settings(
70
  mixed_precision,
71
  logger,
72
  ch_8bit_adam,
 
73
  ):
74
  path_project = os.path.join(path_project_ckpts, project_name)
75
  os.makedirs(path_project, exist_ok=True)
@@ -94,6 +95,7 @@ def save_settings(
94
  "mixed_precision": mixed_precision,
95
  "logger": logger,
96
  "bnb_optimizer": ch_8bit_adam,
 
97
  }
98
  with open(file_setting, "w") as f:
99
  json.dump(settings, f, indent=4)
@@ -126,6 +128,7 @@ def load_settings(project_name):
126
  "mixed_precision": "none",
127
  "logger": "wandb",
128
  "bnb_optimizer": False,
 
129
  }
130
  return (
131
  settings["exp_name"],
@@ -146,6 +149,7 @@ def load_settings(project_name):
146
  settings["mixed_precision"],
147
  settings["logger"],
148
  settings["bnb_optimizer"],
 
149
  )
150
 
151
  with open(file_setting, "r") as f:
@@ -154,6 +158,8 @@ def load_settings(project_name):
154
  settings["logger"] = "wandb"
155
  if "bnb_optimizer" not in settings:
156
  settings["bnb_optimizer"] = False
 
 
157
  if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
158
  settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
159
  return (
@@ -175,6 +181,7 @@ def load_settings(project_name):
175
  settings["mixed_precision"],
176
  settings["logger"],
177
  settings["bnb_optimizer"],
 
178
  )
179
 
180
 
@@ -390,6 +397,7 @@ def start_training(
390
  stream=False,
391
  logger="wandb",
392
  ch_8bit_adam=False,
 
393
  ):
394
  global training_process, tts_api, stop_signal
395
 
@@ -451,7 +459,8 @@ def start_training(
451
  f"--num_warmup_updates {num_warmup_updates} "
452
  f"--save_per_updates {save_per_updates} "
453
  f"--last_per_updates {last_per_updates} "
454
- f"--dataset_name {dataset_name}"
 
455
  )
456
 
457
  if finetune:
@@ -492,6 +501,7 @@ def start_training(
492
  mixed_precision,
493
  logger,
494
  ch_8bit_adam,
 
495
  )
496
 
497
  try:
@@ -1564,6 +1574,13 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1564
  with gr.Row():
1565
  save_per_updates = gr.Number(label="Save per Updates", value=300)
1566
  last_per_updates = gr.Number(label="Last per Updates", value=100)
 
 
 
 
 
 
 
1567
 
1568
  with gr.Row():
1569
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
@@ -1592,6 +1609,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1592
  mixed_precisionv,
1593
  cd_loggerv,
1594
  ch_8bit_adamv,
 
1595
  ) = load_settings(projects_selelect)
1596
  exp_name.value = exp_namev
1597
  learning_rate.value = learning_ratev
@@ -1611,6 +1629,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1611
  mixed_precision.value = mixed_precisionv
1612
  cd_logger.value = cd_loggerv
1613
  ch_8bit_adam.value = ch_8bit_adamv
 
1614
 
1615
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1616
  txt_info_train = gr.Text(label="Info", value="")
@@ -1670,6 +1689,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1670
  ch_stream,
1671
  cd_logger,
1672
  ch_8bit_adam,
 
1673
  ],
1674
  outputs=[txt_info_train, start_button, stop_button],
1675
  )
 
70
  mixed_precision,
71
  logger,
72
  ch_8bit_adam,
73
+ keep_last_n_checkpoints,
74
  ):
75
  path_project = os.path.join(path_project_ckpts, project_name)
76
  os.makedirs(path_project, exist_ok=True)
 
95
  "mixed_precision": mixed_precision,
96
  "logger": logger,
97
  "bnb_optimizer": ch_8bit_adam,
98
+ "keep_last_n_checkpoints": keep_last_n_checkpoints,
99
  }
100
  with open(file_setting, "w") as f:
101
  json.dump(settings, f, indent=4)
 
128
  "mixed_precision": "none",
129
  "logger": "wandb",
130
  "bnb_optimizer": False,
131
+ "keep_last_n_checkpoints": -1, # Default to keep all checkpoints
132
  }
133
  return (
134
  settings["exp_name"],
 
149
  settings["mixed_precision"],
150
  settings["logger"],
151
  settings["bnb_optimizer"],
152
+ settings["keep_last_n_checkpoints"],
153
  )
154
 
155
  with open(file_setting, "r") as f:
 
158
  settings["logger"] = "wandb"
159
  if "bnb_optimizer" not in settings:
160
  settings["bnb_optimizer"] = False
161
+ if "keep_last_n_checkpoints" not in settings:
162
+ settings["keep_last_n_checkpoints"] = -1 # Default to keep all checkpoints
163
  if "last_per_updates" not in settings: # patch for backward compatibility, with before f992c4e
164
  settings["last_per_updates"] = settings["last_per_steps"] // settings["grad_accumulation_steps"]
165
  return (
 
181
  settings["mixed_precision"],
182
  settings["logger"],
183
  settings["bnb_optimizer"],
184
+ settings["keep_last_n_checkpoints"],
185
  )
186
 
187
 
 
397
  stream=False,
398
  logger="wandb",
399
  ch_8bit_adam=False,
400
+ keep_last_n_checkpoints=-1,
401
  ):
402
  global training_process, tts_api, stop_signal
403
 
 
459
  f"--num_warmup_updates {num_warmup_updates} "
460
  f"--save_per_updates {save_per_updates} "
461
  f"--last_per_updates {last_per_updates} "
462
+ f"--dataset_name {dataset_name} "
463
+ f"--keep_last_n_checkpoints {keep_last_n_checkpoints}"
464
  )
465
 
466
  if finetune:
 
501
  mixed_precision,
502
  logger,
503
  ch_8bit_adam,
504
+ keep_last_n_checkpoints,
505
  )
506
 
507
  try:
 
1574
  with gr.Row():
1575
  save_per_updates = gr.Number(label="Save per Updates", value=300)
1576
  last_per_updates = gr.Number(label="Last per Updates", value=100)
1577
+ keep_last_n_checkpoints = gr.Number(
1578
+ label="Keep Last N Checkpoints",
1579
+ value=-1,
1580
+ step=1,
1581
+ precision=0,
1582
+ info="-1: Keep all checkpoints, 0: Only save final model_last.pt, N>0: Keep last N checkpoints",
1583
+ )
1584
 
1585
  with gr.Row():
1586
  ch_8bit_adam = gr.Checkbox(label="Use 8-bit Adam optimizer")
 
1609
  mixed_precisionv,
1610
  cd_loggerv,
1611
  ch_8bit_adamv,
1612
+ keep_last_n_checkpointsv,
1613
  ) = load_settings(projects_selelect)
1614
  exp_name.value = exp_namev
1615
  learning_rate.value = learning_ratev
 
1629
  mixed_precision.value = mixed_precisionv
1630
  cd_logger.value = cd_loggerv
1631
  ch_8bit_adam.value = ch_8bit_adamv
1632
+ keep_last_n_checkpoints.value = keep_last_n_checkpointsv
1633
 
1634
  ch_stream = gr.Checkbox(label="Stream Output Experiment", value=True)
1635
  txt_info_train = gr.Text(label="Info", value="")
 
1689
  ch_stream,
1690
  cd_logger,
1691
  ch_8bit_adam,
1692
+ keep_last_n_checkpoints,
1693
  ],
1694
  outputs=[txt_info_train, start_button, stop_button],
1695
  )
src/f5_tts/train/train.py CHANGED
@@ -61,6 +61,7 @@ def main(cfg):
61
  mel_spec_type=mel_spec_type,
62
  is_local_vocoder=cfg.model.vocoder.is_local,
63
  local_vocoder_path=cfg.model.vocoder.local_path,
 
64
  )
65
 
66
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
 
61
  mel_spec_type=mel_spec_type,
62
  is_local_vocoder=cfg.model.vocoder.is_local,
63
  local_vocoder_path=cfg.model.vocoder.local_path,
64
+ keep_last_n_checkpoints=getattr(cfg.ckpts, "keep_last_n_checkpoints", None),
65
  )
66
 
67
  train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)