Spaces:
Running
Running
Commit
·
69ef772
1
Parent(s):
d9b1955
fix the checkpoint recovery
Browse files
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
|
@@ -325,21 +325,8 @@ class SFTTrainer:
|
|
| 325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
| 326 |
if resume_from_checkpoint == "latest":
|
| 327 |
resume_from_checkpoint = -1
|
| 328 |
-
|
| 329 |
-
# Store the load result
|
| 330 |
-
load_successful = False
|
| 331 |
if resume_from_checkpoint is not None:
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
# If loading succeeded and we have a specific checkpoint path
|
| 335 |
-
if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
|
| 336 |
-
try:
|
| 337 |
-
step = int(resume_from_checkpoint.split("_")[-1])
|
| 338 |
-
self.state.train_state.step = step
|
| 339 |
-
logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
|
| 340 |
-
except (ValueError, IndexError):
|
| 341 |
-
logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
|
| 342 |
-
|
| 343 |
|
| 344 |
def _train(self) -> None:
|
| 345 |
logger.info("Starting training")
|
|
|
|
| 325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
| 326 |
if resume_from_checkpoint == "latest":
|
| 327 |
resume_from_checkpoint = -1
|
|
|
|
|
|
|
|
|
|
| 328 |
if resume_from_checkpoint is not None:
|
| 329 |
+
self.checkpointer.load(resume_from_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
def _train(self) -> None:
|
| 332 |
logger.info("Starting training")
|
vms/ui/project/services/training.py
CHANGED
|
@@ -1097,6 +1097,11 @@ class TrainingService:
|
|
| 1097 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
| 1098 |
checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
|
| 1099 |
logger.info(f"Found checkpoint at step {checkpoint_step}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1100 |
else:
|
| 1101 |
logger.warning("No checkpoints found for recovery")
|
| 1102 |
# Set buttons for no active training
|
|
|
|
| 1097 |
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split("_")[-1]))
|
| 1098 |
checkpoint_step = int(latest_checkpoint.name.split("_")[-1])
|
| 1099 |
logger.info(f"Found checkpoint at step {checkpoint_step}")
|
| 1100 |
+
|
| 1101 |
+
# Both options are valid, but it is simpler to just return "latest".
|
| 1102 |
+
# Under the hood, finetrainers converts "latest" to -1.
|
| 1103 |
+
#latest_checkpoint = int(checkpoint_step)
|
| 1104 |
+
latest_checkpoint = "latest"
|
| 1105 |
else:
|
| 1106 |
logger.warning("No checkpoints found for recovery")
|
| 1107 |
# Set buttons for no active training
|