bird-of-paradise committed on
Commit cfa2a65 · verified · 1 Parent(s): 2fc6f4d

first commit --curriculum callback

Files changed (1)
  1. src/utils/callbacks.py +195 -0
src/utils/callbacks.py ADDED
@@ -0,0 +1,195 @@
+ from transformers import TrainerCallback
+ from transformers.trainer_callback import TrainerControl, TrainerState
+ from transformers.training_args import TrainingArguments
+
+
+ class CurriculumLearningCallback(TrainerCallback):
+     """A callback to implement curriculum learning stages during training."""
+     def __init__(self, debug=False):
+         self.debug = debug
+         self.current_stage = "format_stage"
+         self.stages = {
+             "format_stage": {
+                 "reward_weights": {"format": 1.0, "accuracy": 0.0, "code_execution": 0.0,
+                                    "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
+                 "beta": 0.1,   # Higher KL - stay close to base model format
+                 "steps": 1000
+             },
+             "code_execution_stage": {
+                 "reward_weights": {"format": 0.3, "accuracy": 0.0, "code_execution": 0.7,
+                                    "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
+                 "beta": 0.05,  # Medium KL
+                 "steps": 2000
+             },
+             "accuracy_stage": {
+                 "reward_weights": {"format": 0.2, "accuracy": 0.8, "code_execution": 0.0,
+                                    "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
+                 "beta": 0.01,  # Very low KL - allow exploration
+                 "steps": 3000
+             },
+             "refinement_stage": {
+                 "reward_weights": {"format": 0.1, "accuracy": 0.6, "code_execution": 0.1,
+                                    "length": 0.1, "code_ratio": 0.05, "code_timing": 0.05},
+                 "beta": 0.03,  # Medium-low KL - stabilize learning
+                 "steps": 5000
+             }
+         }
+         self.total_steps = sum(stage_config["steps"] for stage_config in self.stages.values())
+         self.stage_transitions = self._calculate_stage_transitions()
+
+         print(f"Curriculum learning initialized with {len(self.stages)} stages:")
+         for stage, end_step in self.stage_transitions.items():
+             print(f" {stage}: ends at step {end_step}")
+
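+     # With the step counts above, _calculate_stage_transitions yields the
+     # cumulative boundaries: format_stage ends at step 1000,
+     # code_execution_stage at 3000, accuracy_stage at 6000, and
+     # refinement_stage at 11000 (= total_steps).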
+     def _calculate_stage_transitions(self):
+         """Calculate at which step each stage transition occurs."""
+         transitions = {}
+         current_step = 0
+         for stage, config in self.stages.items():
+             current_step += config["steps"]
+             transitions[stage] = current_step
+         return transitions
+
+     def on_train_begin(self, args: TrainingArguments, state: TrainerState,
+                        control: TrainerControl, **kwargs):
+         """Initialize reward weights and beta at the start of training."""
+         # NOTE: the stock Trainer does not pass itself into callback kwargs,
+         # so a trainer reference must be forwarded by the caller (see the
+         # wiring sketch below); otherwise this callback is a no-op.
+         trainer = kwargs.get("trainer")
+         if trainer is None:
+             return
+
+         # Set initial weights and beta from the first stage
+         first_stage = next(iter(self.stages))
+         stage_config = self.stages[first_stage]
+
+         # Update reward weights
+         if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
+             for i, func_name in enumerate(trainer.reward_func_names):
+                 if func_name in stage_config["reward_weights"]:
+                     trainer.reward_weights[i] = stage_config["reward_weights"][func_name]
+                     if self.debug:
+                         print(f"Setting initial weight for {func_name}: {trainer.reward_weights[i]}")
+         else:
+             print("Warning: Trainer doesn't have reward_weights or reward_func_names attributes")
+
+         # Update beta (KL coefficient)
+         if hasattr(trainer, "beta"):
+             trainer.beta = stage_config.get("beta", 0.1)
+             if self.debug:
+                 print(f"Setting initial beta: {trainer.beta}")
+         else:
+             print("Warning: Trainer doesn't have a beta attribute")
+
+     def on_step_end(self, args: TrainingArguments, state: TrainerState,
+                     control: TrainerControl, **kwargs):
+         """Update reward weights and beta based on the current training stage."""
+         trainer = kwargs.get("trainer")
+         if trainer is None:
+             return
+
+         # Check whether it's time to transition to the next stage
+         current_step = state.global_step
+
+         # Determine the current stage. Sort by transition step, not by the
+         # dict items themselves: a bare sorted() over the items would order
+         # the stages alphabetically and break the curriculum order.
+         previous_stage = self.current_stage
+         for stage, transition_step in sorted(self.stage_transitions.items(),
+                                              key=lambda item: item[1]):
+             if current_step <= transition_step:
+                 self.current_stage = stage
+                 break
+
+         # If the stage changed, log the transition
+         if previous_stage != self.current_stage:
+             print(f"Transitioning from {previous_stage} to {self.current_stage} at step {current_step}")
+
+         # Get the config for the current stage
+         stage_config = self.stages[self.current_stage]
+
+         # Update reward weights
+         if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
+             for i, func_name in enumerate(trainer.reward_func_names):
+                 if func_name in stage_config["reward_weights"]:
+                     new_weight = stage_config["reward_weights"][func_name]
+                     if trainer.reward_weights[i] != new_weight:
+                         trainer.reward_weights[i] = new_weight
+                         if self.debug:
+                             print(f"Updated weight for {func_name}: {new_weight}")
+
+         # Update beta (KL coefficient)
+         if hasattr(trainer, "beta"):
+             new_beta = stage_config.get("beta", 0.1)
+             if trainer.beta != new_beta:
+                 trainer.beta = new_beta
+                 if self.debug:
+                     print(f"Updated beta: {new_beta}")
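The callback only does anything if a trainer object shows up in the callback kwargs, which the stock Hugging Face Trainer does not provide, and it assumes the trainer exposes reward_func_names, reward_weights, and beta, which matches TRL's GRPOTrainer in recent versions. Below is a minimal wiring sketch under those assumptions, not part of the commit: TrainerAwareCurriculum and the stub reward functions and dataset are hypothetical names introduced here for illustration.

    # Minimal wiring sketch (assumes TRL's GRPOTrainer; everything named
    # here besides CurriculumLearningCallback is a hypothetical placeholder).
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    # GRPOTrainer derives reward_func_names from each function's __name__,
    # so these stubs are deliberately named after the stage keys.
    def format(completions, **kwargs):
        return [0.0 for _ in completions]

    def accuracy(completions, **kwargs):
        return [0.0 for _ in completions]

    class TrainerAwareCurriculum(CurriculumLearningCallback):
        """Forwards an injected trainer reference into the callback kwargs."""
        def __init__(self, debug=False):
            super().__init__(debug=debug)
            self.trainer = None  # injected after the trainer is constructed

        def on_train_begin(self, args, state, control, **kwargs):
            super().on_train_begin(args, state, control, trainer=self.trainer, **kwargs)

        def on_step_end(self, args, state, control, **kwargs):
            super().on_step_end(args, state, control, trainer=self.trainer, **kwargs)

    dataset = Dataset.from_dict({"prompt": ["Solve 2 + 2.", "Solve 3 * 7."]})
    curriculum = TrainerAwareCurriculum(debug=True)
    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",  # example model id
        reward_funcs=[format, accuracy],
        args=GRPOConfig(output_dir="outputs", beta=0.1),
        train_dataset=dataset,
        callbacks=[curriculum],
    )
    curriculum.trainer = trainer  # inject the reference the callback expects
    trainer.train()

Note that for the stage weights to apply, the reward function names must match the keys used in self.stages ("format", "accuracy", "code_execution", and so on); a function named format_reward, for example, would be silently skipped by the weight-update loop.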