Commit
·
4f95ee1
1
Parent(s):
488d746
Fix for the checkpoint dropdown menu
Browse files
src/f5_tts/train/finetune_gradio.py
CHANGED
@@ -1261,12 +1261,22 @@ def get_checkpoints_project(project_name, is_gradio=True):
|
|
1261 |
|
1262 |
if os.path.isdir(path_project_ckpts):
|
1263 |
files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
|
1264 |
-
|
1265 |
-
|
1266 |
-
|
1267 |
-
|
1268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1269 |
)
|
|
|
|
|
|
|
1270 |
else:
|
1271 |
files_checkpoints = []
|
1272 |
|
|
|
1261 |
|
1262 |
if os.path.isdir(path_project_ckpts):
|
1263 |
files_checkpoints = glob(os.path.join(path_project_ckpts, project_name, "*.pt"))
|
1264 |
+
# Separate pretrained and regular checkpoints
|
1265 |
+
pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
|
1266 |
+
regular_checkpoints = [
|
1267 |
+
f
|
1268 |
+
for f in files_checkpoints
|
1269 |
+
if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
|
1270 |
+
]
|
1271 |
+
last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
|
1272 |
+
|
1273 |
+
# Sort regular checkpoints by number
|
1274 |
+
regular_checkpoints = sorted(
|
1275 |
+
regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
|
1276 |
)
|
1277 |
+
|
1278 |
+
# Combine in order: pretrained, regular, last
|
1279 |
+
files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
|
1280 |
else:
|
1281 |
files_checkpoints = []
|
1282 |
|