Yushen CHEN committed
Commit c2cf31e · unverified · 2 Parent(s): 46266f1 2d27d2c

Merge pull request #729 from hcsolakoglu/fix-ckpt-rotation

Exclude pretrained models from the checkpoint rotation logic
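
For context, the updated rotation filter only ever deletes per-update training checkpoints; the copied pretrained model and model_last.pt never become rotation candidates. A minimal standalone sketch of that filter follows; the filenames and the keep value are hypothetical, not taken from the diff or a real run:

# Minimal sketch of the new rotation filter (hypothetical filenames and keep value).
files = [
    "pretrained_model_1200000.pt",  # copied base model: excluded from rotation
    "model_last.pt",                # always kept
    "model_1000.pt",
    "model_2000.pt",
    "model_3000.pt",
]

keep_last_n_checkpoints = 2
checkpoints = [
    f
    for f in files
    if f.startswith("model_")
    and not f.startswith("pretrained_")  # pretrained copy never enters the rotation list
    and f.endswith(".pt")
    and f != "model_last.pt"
]
checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

while len(checkpoints) > keep_last_n_checkpoints:
    print("would delete:", checkpoints.pop(0))  # only model_1000.pt in this example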

src/f5_tts/model/trainer.py CHANGED
@@ -160,10 +160,14 @@ class Trainer:
                     return
                 self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{update}.pt")
                 if self.keep_last_n_checkpoints > 0:
+                    # Updated logic to exclude pretrained model from rotation
                     checkpoints = [
                         f
                         for f in os.listdir(self.checkpoint_path)
-                        if f.startswith("model_") and f.endswith(".pt") and f != "model_last.pt"
+                        if f.startswith("model_")
+                        and not f.startswith("pretrained_")  # Exclude pretrained models
+                        and f.endswith(".pt")
+                        and f != "model_last.pt"
                     ]
                     checkpoints.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
                     while len(checkpoints) > self.keep_last_n_checkpoints:
@@ -183,10 +187,24 @@ class Trainer:
         if "model_last.pt" in os.listdir(self.checkpoint_path):
             latest_checkpoint = "model_last.pt"
         else:
-            latest_checkpoint = sorted(
-                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
-                key=lambda x: int("".join(filter(str.isdigit, x))),
-            )[-1]
+            # Updated to consider pretrained models for loading but prioritize training checkpoints
+            all_checkpoints = [
+                f
+                for f in os.listdir(self.checkpoint_path)
+                if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
+            ]
+
+            # First try to find regular training checkpoints
+            training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
+            if training_checkpoints:
+                latest_checkpoint = sorted(
+                    training_checkpoints,
+                    key=lambda x: int("".join(filter(str.isdigit, x))),
+                )[-1]
+            else:
+                # If no training checkpoints, use pretrained model
+                latest_checkpoint = next(f for f in all_checkpoints if f.startswith("pretrained_"))
+
         # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
         checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
 
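The load-time change can be read as the standalone paraphrase below, handy for a quick sanity check outside the class. This is not the actual Trainer.load_checkpoint method; the helper name, the temporary directory, and the filenames are made up for illustration:

# Standalone paraphrase of the updated selection order (hypothetical helper and filenames).
import os
import tempfile

def pick_latest_checkpoint(checkpoint_path):
    files = os.listdir(checkpoint_path)
    if "model_last.pt" in files:
        return "model_last.pt"
    all_checkpoints = [
        f for f in files
        if (f.startswith("model_") or f.startswith("pretrained_")) and f.endswith(".pt")
    ]
    training_checkpoints = [f for f in all_checkpoints if f.startswith("model_") and f != "model_last.pt"]
    if training_checkpoints:
        # Newest per-update checkpoint wins, e.g. model_3000.pt over model_2000.pt
        return sorted(training_checkpoints, key=lambda x: int("".join(filter(str.isdigit, x))))[-1]
    # Fresh finetune run: only the copied pretrained model is present
    return next(f for f in all_checkpoints if f.startswith("pretrained_"))

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "pretrained_model_1200000.pt"), "w").close()
    print(pick_latest_checkpoint(d))  # pretrained_model_1200000.pt
    open(os.path.join(d, "model_2000.pt"), "w").close()
    print(pick_latest_checkpoint(d))  # model_2000.pt
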
src/f5_tts/train/finetune_cli.py CHANGED
@@ -111,7 +111,8 @@ def main():
         if not os.path.isdir(checkpoint_path):
             os.makedirs(checkpoint_path, exist_ok=True)
 
-        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+        # Change: Add 'pretrained_' prefix to copied model
+        file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
         if not os.path.isfile(file_checkpoint):
             shutil.copy2(ckpt_path, file_checkpoint)
             print("copy checkpoint for finetune")
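As a quick illustration of the renamed copy this hunk produces (both paths below are hypothetical, for illustration only):

# Illustration only: hypothetical source checkpoint and project directory.
import os

ckpt_path = "/path/to/F5TTS_Base/model_1200000.pt"
checkpoint_path = "ckpts/my_finetune_project"

file_checkpoint = os.path.join(checkpoint_path, "pretrained_" + os.path.basename(ckpt_path))
print(file_checkpoint)  # ckpts/my_finetune_project/pretrained_model_1200000.pt
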
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1099,7 +1099,9 @@ def vocab_extend(project_name, symbols, model_type):
     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
     os.makedirs(new_ckpt_path, exist_ok=True)
-    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")
+
+    # Add pretrained_ prefix to model when copying for consistency with finetune_cli.py
+    new_ckpt_file = os.path.join(new_ckpt_path, "pretrained_model_1200000.pt")
 
     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)