bird-of-paradise committed
Commit e9196fe · verified · 1 Parent(s): f757722

Use weighted list reward functions

separate the calculation of advantage from the calculation of rewards
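
For context, a minimal sketch of the reward-function interface the new `_compute_rewards` appears to assume: each entry in `self.reward_funcs` takes `prompts`, `completions`, an optional `completion_ids`, and extra dataset columns as keyword arguments, returns one float (or None) per completion, and the per-function scores are combined with `self.reward_weights` via a NaN-aware weighted sum. The function names, weights, and example data below are hypothetical, not part of this commit.

import torch

# Hypothetical reward functions following the assumed interface:
# one float (or None) per completion.
def correctness_reward(prompts, completions, completion_ids=None, **kwargs):
    solutions = kwargs.get("solution", [None] * len(completions))
    return [1.0 if sol is not None and sol in comp else -1.0
            for comp, sol in zip(completions, solutions)]

def length_penalty(prompts, completions, completion_ids=None, **kwargs):
    # Returns None for empty completions to exercise the NaN handling.
    return [-0.001 * len(comp) if comp else None for comp in completions]

reward_funcs = [correctness_reward, length_penalty]  # assumed trainer attribute
reward_weights = torch.tensor([1.0, 0.5])            # assumed trainer attribute

prompts = ["What is 2+2?", "What is 3*3?"]
completions = ["The answer is 4", ""]
extra_cols = {"solution": ["4", "9"]}  # e.g. aliased from "correct_answers"

# Same combination rule as _compute_rewards: per-function scores, None -> NaN,
# then a weighted nansum across reward functions.
per_func = torch.full((len(prompts), len(reward_funcs)), float("nan"))
for i, func in enumerate(reward_funcs):
    scores = func(prompts=prompts, completions=completions, **extra_cols)
    per_func[:, i] = torch.tensor(
        [s if s is not None else float("nan") for s in scores])

combined = (per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
print(combined)  # approximately tensor([ 0.9925, -1.0000])

In the trainer itself, functions whose weight is effectively zero are skipped and their column left as NaN, which the nansum then ignores.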

Files changed (1)
  1. src/retool_trainer.py +89 -12
src/retool_trainer.py CHANGED
@@ -7,7 +7,7 @@ import datasets
  import torch
  import torch.utils.data
  import transformers
- from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+ #from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
  from datasets import Dataset, IterableDataset
  from packaging import version
  from torch import nn
@@ -163,19 +163,96 @@ class ReToolTrainer(Trainer): # Change this line
              return self._check_equivalence(predicted, ground_truth)
          return False
 
-     def _compute_rewards_and_advantages(self, completions_text, ground_truths, device):
-         """Simplified reward and advantage computation for ReTool."""
- 
-         # Compute binary rewards
-         rewards = []
-         for completion_text, ground_truth in zip(completions_text, ground_truths):
-             if self._is_correct_answer(completion_text, ground_truth):
-                 rewards.append(1.0)
-             else:
-                 rewards.append(-1.0)
- 
-         # For now: advantages = rewards (skip group normalization)
-         advantages = torch.tensor(rewards, dtype=torch.float32, device=device)
+     def _compute_rewards(self, inputs, prompts, completions, completion_ids_list=None):
+         """Calculate rewards for completions and combine them according to weights."""
+         device = self.device  # Your device might be set differently
+         rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+ 
+         # Extract additional arguments from inputs if needed
+         reward_kwargs = {}
+         if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
+             keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]]
+             reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+ 
+         # Add correct_answers to kwargs if present (common in math reasoning tasks)
+         if "correct_answers" in reward_kwargs:
+             reward_kwargs["solution"] = reward_kwargs["correct_answers"]  # Alias for compatibility
+ 
+         # Calculate rewards for each function with non-zero weight
+         for i, (reward_func, func_name) in enumerate(zip(self.reward_funcs, self.reward_func_names)):
+             # Skip computation if weight is zero
+             if abs(self.reward_weights[i].item()) < 1e-6:
+                 rewards_per_func[:, i] = float('nan')
+                 if self.verbose:
+                     print(f"Skipping reward '{func_name}' (zero weight)")
+                 continue
+ 
+             # Calculate reward
+             try:
+                 # Call the reward function with appropriate arguments
+                 rewards = reward_func(
+                     prompts=prompts,
+                     completions=completions,
+                     completion_ids=completion_ids_list if completion_ids_list is not None else None,
+                     **reward_kwargs
+                 )
+ 
+                 # Convert None values to NaN and ensure it's a tensor
+                 rewards = [r if r is not None else float('nan') for r in rewards]
+                 rewards_per_func[:, i] = torch.tensor(rewards, dtype=torch.float32, device=device)
+ 
+                 # Log reward statistics if verbose
+                 if self.verbose:
+                     valid_rewards = [r for r in rewards if not (r is None or (isinstance(r, float) and math.isnan(r)))]
+                     if valid_rewards:
+                         print(f"Reward '{func_name}': min={min(valid_rewards):.4f}, max={max(valid_rewards):.4f}, "
+                               f"mean={sum(valid_rewards)/len(valid_rewards):.4f}")
+             except Exception as e:
+                 print(f"Error in reward function '{func_name}': {e}")
+                 rewards_per_func[:, i] = float('nan')
+ 
+         # Combine rewards using weights
+         rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+ 
+         # Convert to list for easier handling
+         final_rewards = rewards.cpu().tolist()
+ 
+         return final_rewards
+ 
+     def compute_rewards_and_advantages(self, inputs, prompts, completions, completion_ids_list=None):
+         """Calculate rewards and compute advantages based on those rewards."""
+         # First calculate rewards
+         rewards = self._compute_rewards(inputs, prompts, completions, completion_ids_list)
+ 
+         # Convert to tensor if not already
+         if not isinstance(rewards, torch.Tensor):
+             rewards = torch.tensor(rewards, dtype=torch.float32, device=self.device)
+ 
+         # For now, simple advantage calculation
+         advantages = rewards.clone()  # Simple case: advantages = rewards
+ 
+         # If later I want to implement GRPO-style advantage calculation:
+         if self.use_grouped_advantages:
+             # Reshape rewards into groups (assuming self.num_generations is set)
+             grouped_rewards = rewards.view(-1, self.num_generations)
+ 
+             # Calculate statistics per group
+             mean_grouped_rewards = grouped_rewards.mean(dim=1)
+             std_grouped_rewards = grouped_rewards.std(dim=1)
+ 
+             # Expand means and stds to match original shape
+             mean_expanded = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+             std_expanded = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ 
+             # Compute advantages: reward - baseline
+             advantages = rewards - mean_expanded
+ 
+             # Optionally normalize advantages
+             if self.normalize_advantages:
+                 # Avoid division by zero
+                 std_expanded = torch.clamp(std_expanded, min=1e-8)
+                 advantages = advantages / std_expanded
 
          return advantages
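
As a usage note, the grouped branch of compute_rewards_and_advantages mirrors a GRPO-style baseline: rewards are reshaped into groups of self.num_generations completions per prompt, the group mean is subtracted, and the result is optionally divided by the group standard deviation. Below is a standalone sketch of that computation under those assumptions; the helper name and example rewards are illustrative, not part of the commit.

import torch

def grouped_advantages(rewards, num_generations, normalize=True, eps=1e-8):
    """Group rewards per prompt, subtract the group-mean baseline, and
    optionally normalize by the group standard deviation (GRPO-style)."""
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    grouped = rewards.view(-1, num_generations)  # (num_prompts, num_generations)
    mean = grouped.mean(dim=1).repeat_interleave(num_generations)
    advantages = rewards - mean
    if normalize:
        std = grouped.std(dim=1).repeat_interleave(num_generations)
        advantages = advantages / torch.clamp(std, min=eps)
    return advantages

# Example: 2 prompts with 3 sampled completions each.
rewards = [1.0, -1.0, 1.0, -1.0, -1.0, 1.0]
print(grouped_advantages(rewards, num_generations=3, normalize=False))
# tensor([ 0.6667, -1.3333,  0.6667, -0.6667, -0.6667,  1.3333])

Whether the trainer takes this path depends on self.use_grouped_advantages and self.normalize_advantages, which are set elsewhere in the class.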