from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
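
# The reward names used in the stage configs below ("format", "accuracy",
# "code_execution", "length", "code_ratio", "code_timing") are assumed to
# match the entries the trainer reports in `reward_func_names` (TRL's
# GRPOTrainer, for instance, derives those names from the reward functions
# it is given); trainer reward functions with unmatched names keep their
# previous weights.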


class CurriculumLearningCallback(TrainerCallback):
    """A callback to implement curriculum learning stages during training."""

    def __init__(self, debug: bool = False):
        self.debug = debug
        self.current_stage = "format_stage"
        # Each stage specifies per-reward-function weights, a KL penalty
        # coefficient (beta), and the number of training steps it lasts.
        self.stages = {
            "format_stage": {
                "reward_weights": {"format": 1.0, "accuracy": 0.0, "code_execution": 0.0,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.1,
                "steps": 1000
            },
            "code_execution_stage": {
                "reward_weights": {"format": 0.3, "accuracy": 0.0, "code_execution": 0.7,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.05,
                "steps": 2000
            },
            "accuracy_stage": {
                "reward_weights": {"format": 0.2, "accuracy": 0.8, "code_execution": 0.0,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.01,
                "steps": 3000
            },
            "refinement_stage": {
                "reward_weights": {"format": 0.1, "accuracy": 0.6, "code_execution": 0.1,
                                   "length": 0.1, "code_ratio": 0.05, "code_timing": 0.05},
                "beta": 0.03,
                "steps": 5000
            }
        }
        self.total_steps = sum(stage_config["steps"] for stage_config in self.stages.values())
        self.stage_transitions = self._calculate_stage_transitions()
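        # With the defaults above, the cumulative schedule is: format_stage
        # ends at step 1000, code_execution_stage at 3000, accuracy_stage at
        # 6000, and refinement_stage at 11000 (= self.total_steps).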

        print(f"Curriculum learning initialized with {len(self.stages)} stages:")
        for stage, end_step in self.stage_transitions.items():
            print(f"  {stage}: ends at step {end_step}")

    def _calculate_stage_transitions(self):
        """Calculate the cumulative step at which each stage ends."""
        transitions = {}
        current_step = 0
        for stage, config in self.stages.items():
            current_step += config["steps"]
            transitions[stage] = current_step
        return transitions

    def on_train_begin(self, args: TrainingArguments, state: TrainerState,
                       control: TrainerControl, **kwargs):
        """Initialize reward weights and beta at the start of training."""
        # Note: the HF Trainer does not pass itself in callback kwargs, so a
        # `trainer` reference must be injected (see the sketch at the bottom).
        trainer = kwargs.get("trainer")
        if trainer is None:
            return

        first_stage = next(iter(self.stages))
        stage_config = self.stages[first_stage]

        if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
            for i, func_name in enumerate(trainer.reward_func_names):
                if func_name in stage_config["reward_weights"]:
                    trainer.reward_weights[i] = stage_config["reward_weights"][func_name]
                    if self.debug:
                        print(f"Setting initial weight for {func_name}: {trainer.reward_weights[i]}")
        else:
            print("Warning: trainer has no reward_weights or reward_func_names attributes")

        if hasattr(trainer, "beta"):
            trainer.beta = stage_config.get("beta", 0.1)
            if self.debug:
                print(f"Setting initial beta: {trainer.beta}")
        else:
            print("Warning: trainer has no beta attribute")

    def on_step_end(self, args: TrainingArguments, state: TrainerState,
                    control: TrainerControl, **kwargs):
        """Update reward weights and beta based on the current training stage."""
        trainer = kwargs.get("trainer")
        if trainer is None:
            return

        current_step = state.global_step

        # Pick the first stage whose cumulative end step has not yet passed.
        # Sort by transition step, not by stage name: a plain sorted() over
        # the dict items orders alphabetically (accuracy_stage first) and
        # would assign early steps to the wrong stage; e.g. step 2500 must
        # map to code_execution_stage (1000 < 2500 <= 3000).
        previous_stage = self.current_stage
        for stage, transition_step in sorted(self.stage_transitions.items(),
                                             key=lambda item: item[1]):
            if current_step <= transition_step:
                self.current_stage = stage
                break

        if previous_stage != self.current_stage:
            print(f"Transitioning from {previous_stage} to {self.current_stage} at step {current_step}")

        stage_config = self.stages[self.current_stage]

        if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
            for i, func_name in enumerate(trainer.reward_func_names):
                if func_name in stage_config["reward_weights"]:
                    new_weight = stage_config["reward_weights"][func_name]
                    if trainer.reward_weights[i] != new_weight:
                        trainer.reward_weights[i] = new_weight
                        if self.debug:
                            print(f"Updated weight for {func_name}: {new_weight}")

        if hasattr(trainer, "beta"):
            new_beta = stage_config.get("beta", 0.1)
            if trainer.beta != new_beta:
                trainer.beta = new_beta
                if self.debug:
                    print(f"Updated beta: {new_beta}")