hcsolakoglu commited on
Commit
4f95ee1
·
1 Parent(s): 488d746

Fix for the checkpoint dropdown menu

Browse files
Files changed (1) hide show
  1. src/f5_tts/train/finetune_gradio.py +15 -5
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1261,12 +1261,22 @@ def get_checkpoints_project(project_name, is_gradio=True):
1261
 
1262
  if os.path.isdir(path_project_ckpts):
1263
  files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
1264
- files_checkpoints = sorted(
1265
- files_checkpoints,
1266
- key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
1267
- if os.path.basename(x) != "model_last.pt"
1268
- else float("inf"),
 
 
 
 
 
 
 
1269
  )
 
 
 
1270
  else:
1271
  files_checkpoints = []
1272
 
 
1261
 
1262
  if os.path.isdir(path_project_ckpts):
1263
  files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
1264
+ # Separate pretrained and regular checkpoints
1265
+ pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
1266
+ regular_checkpoints = [
1267
+ f
1268
+ for f in files_checkpoints
1269
+ if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
1270
+ ]
1271
+ last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
1272
+
1273
+ # Sort regular checkpoints by number
1274
+ regular_checkpoints = sorted(
1275
+ regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
1276
  )
1277
+
1278
+ # Combine in order: pretrained, regular, last
1279
+ files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
1280
  else:
1281
  files_checkpoints = []
1282