Use a weighted list of reward functions
Separate the calculation of advantage from the calculation of rewards
- src/retool_trainer.py +89 -12
src/retool_trainer.py
CHANGED
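For orientation, here is a minimal sketch (not part of the commit) of how the two methods introduced in the diff below are meant to be called. It assumes an already-constructed ReToolTrainer instance and an already-generated batch; trainer, inputs, prompts, completions and completion_ids_list are hypothetical placeholders with the same meaning as the parameters in the diff.

# Hypothetical usage sketch: rewards can now be computed and inspected on their own ...
rewards = trainer._compute_rewards(inputs, prompts, completions, completion_ids_list)
# ... while the training loop only needs the advantages, which are derived from a
# fresh reward computation inside compute_rewards_and_advantages.
advantages = trainer.compute_rewards_and_advantages(inputs, prompts, completions, completion_ids_list)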
@@ -7,7 +7,7 @@ import datasets
 import torch
 import torch.utils.data
 import transformers
-from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+#from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from datasets import Dataset, IterableDataset
 from packaging import version
 from torch import nn
@@ -163,19 +163,96 @@ class ReToolTrainer(Trainer): # Change this line
             return self._check_equivalence(predicted, ground_truth)
         return False
 
-    def
-        """
-
-        #
-
-
-
-
-
-        #
-
+    def _compute_rewards(self, inputs, prompts, completions, completion_ids_list=None):
+        """Calculate rewards for completions and combine them according to weights."""
+        device = self.device  # Your device might be set differently
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+
+        # Extract additional arguments from inputs if needed
+        reward_kwargs = {}
+        if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
+            keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
+            reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+
+        # Add correct_answers to kwargs if present (common in math reasoning tasks)
+        if "correct_answers" in reward_kwargs:
+            reward_kwargs["solution"] = reward_kwargs["correct_answers"]  # Alias for compatibility
+
+        # Calculate rewards for each function with a non-zero weight
+        for i, (reward_func, func_name) in enumerate(zip(self.reward_funcs, self.reward_func_names)):
+            # Skip computation if the weight is zero; the NaN column is ignored by nansum below
+            if abs(self.reward_weights[i].item()) < 1e-6:
+                rewards_per_func[:, i] = float('nan')
+                if self.verbose:
+                    print(f"Skipping reward '{func_name}' (zero weight)")
+                continue
+
+            # Calculate the reward
+            try:
+                # Call the reward function with the appropriate arguments
+                rewards = reward_func(
+                    prompts=prompts,
+                    completions=completions,
+                    completion_ids=completion_ids_list,
+                    **reward_kwargs,
+                )
+
+                # Convert None values to NaN and ensure the result is a tensor
+                rewards = [r if r is not None else float('nan') for r in rewards]
+                rewards_per_func[:, i] = torch.tensor(rewards, dtype=torch.float32, device=device)
+
+                # Log reward statistics if verbose
+                if self.verbose:
+                    valid_rewards = [r for r in rewards if not (isinstance(r, float) and math.isnan(r))]
+                    if valid_rewards:
+                        print(f"Reward '{func_name}': min={min(valid_rewards):.4f}, max={max(valid_rewards):.4f}, "
+                              f"mean={sum(valid_rewards) / len(valid_rewards):.4f}")
+            except Exception as e:
+                print(f"Error in reward function '{func_name}': {e}")
+                rewards_per_func[:, i] = float('nan')
+
+        # Combine rewards using the weights; nansum ignores skipped (NaN) reward functions
+        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+
+        # Convert to a list for easier handling
+        final_rewards = rewards.cpu().tolist()
+
+        return final_rewards
+
+    def compute_rewards_and_advantages(self, inputs, prompts, completions, completion_ids_list=None):
+        """Calculate rewards and compute advantages based on those rewards."""
+        # First calculate the rewards
+        rewards = self._compute_rewards(inputs, prompts, completions, completion_ids_list)
+
+        # Convert to a tensor if not already one
+        if not isinstance(rewards, torch.Tensor):
+            rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
+
+        # Simple case: advantages are just the rewards
+        advantages = rewards.clone()
+
+        # GRPO-style advantage calculation over groups of generations
+        if self.use_grouped_advantages:
+            # Reshape rewards into groups (assumes self.num_generations is set)
+            grouped_rewards = rewards.view(-1, self.num_generations)
+
+            # Calculate statistics per group
+            mean_grouped_rewards = grouped_rewards.mean(dim=1)
+            std_grouped_rewards = grouped_rewards.std(dim=1)
+
+            # Expand means and stds to match the original shape
+            mean_expanded = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+            std_expanded = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+
+            # Compute advantages: reward minus the group baseline
+            advantages = rewards - mean_expanded
+
+            # Optionally normalize advantages
+            if self.normalize_advantages:
+                # Avoid division by zero
+                std_expanded = torch.clamp(std_expanded, min=1e-8)
+                advantages = advantages / std_expanded
+
         return advantages
 
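As a sanity check on the combination and grouping logic above, here is a small self-contained example with made-up numbers: two reward functions (the second given zero weight, so its column stays NaN, mirroring the zero-weight skip in _compute_rewards), four completions forming two groups of two generations, combined with nansum and then normalized per group the same way compute_rewards_and_advantages does. The tensors and values are illustrative only.

import torch

# Toy setup: 4 completions, 2 reward functions; the second function has weight 0
# and its column is left as NaN, as in the zero-weight skip above.
reward_weights = torch.tensor([1.0, 0.0])
rewards_per_func = torch.tensor([
    [1.0, float('nan')],
    [0.0, float('nan')],
    [0.5, float('nan')],
    [1.0, float('nan')],
])

# Weighted combination; nansum ignores the skipped (NaN) column.
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
print(rewards)  # tensor([1.0000, 0.0000, 0.5000, 1.0000])

# Grouped (GRPO-style) advantages with num_generations = 2:
num_generations = 2
grouped = rewards.view(-1, num_generations)  # shape (2, 2): two prompts, two generations each
mean = grouped.mean(dim=1).repeat_interleave(num_generations)
std = grouped.std(dim=1).repeat_interleave(num_generations).clamp(min=1e-8)
advantages = (rewards - mean) / std
print(advantages)  # each pair of generations is centered on 0 and scaled by its group's std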