Commit 2fc6f4d · verified · Parent(s): e9196fe
bird-of-paradise committed

Add reward functions and registry

Files changed (1): src/rewards.py (+156, -0)
src/rewards.py ADDED
import re
from typing import Callable

def format_reward(completions, **kwargs):
    """Reward function that checks that the code is enclosed within <code> and
    </code> tags, and the final answer within <answer> and </answer> tags."""
    pattern = r"<code>.*?</code>.*?<answer>.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

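# Example (illustrative): completions use the chat format common to TRL-style
# trainers -- a list of message lists, each message a {"role", "content"} dict:
#
#   good = [[{"role": "assistant", "content":
#             "<code>print(2+2)</code> so <answer>\\boxed{4}</answer>"}]]
#   format_reward(good)  # -> [1.0]
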
def accuracy_reward(completions: list[list[dict[str, str]]], correct_answers: list[str], **kwargs) -> list[float]:
    """Reward function that checks if the completion's answer matches the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, correct_answer in zip(contents, correct_answers):
        # Extract the answer from the completion: <answer>\boxed{...}</answer>
        answer_match = re.search(r'<answer>\\boxed\{(.*?)\}</answer>', content, re.DOTALL)
        if not answer_match:
            rewards.append(0.0)
            continue

        extracted_answer = answer_match.group(1).strip()

        # Exact string comparison; mathematical expressions may need a more
        # sophisticated equivalence check (see the note below)
        rewards.append(1.0 if extracted_answer == correct_answer else 0.0)

    return rewards

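# Note (illustrative): exact string equality is brittle for math answers
# ("0.5" vs "1/2" would not match). A sketch of a more tolerant comparison,
# assuming sympy is installed and both answers parse as expressions:
#
#   import sympy
#   def answers_match(a: str, b: str) -> bool:
#       try:
#           return sympy.simplify(sympy.sympify(a) - sympy.sympify(b)) == 0
#       except (sympy.SympifyError, TypeError):
#           return a.strip() == b.strip()
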
def code_execution_reward(completions, **kwargs):
    """Reward function that checks whether the code execution was successful."""
    completion_contents = [completion[0]["content"] for completion in completions]
    # Error signatures to look for in the interpreter output. The findall below
    # captures only the text *between* the interpreter tags, so the patterns
    # must not include the tags themselves.
    error_patterns = [r'Error', r'Exception', r'Traceback']

    rewards = []
    for content in completion_contents:
        # Find all code-interpreter pairs and capture each interpreter output
        interpreter_outputs = re.findall(r'<code>.*?</code>\s*<interpreter>(.*?)</interpreter>', content, re.DOTALL)
        if not interpreter_outputs:
            rewards.append(0.0)
            continue

        # Count interpreter outputs that contain an error signature
        error_count = sum(
            1 for output in interpreter_outputs
            if any(re.search(pattern, output) for pattern in error_patterns)
        )

        # Reward the fraction of code blocks that executed cleanly
        rewards.append(1.0 - error_count / len(interpreter_outputs))

    return rewards

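# Example (illustrative): a traceback in any interpreter output lowers the
# score proportionally.
#
#   ok  = "<code>print(1)</code><interpreter>1</interpreter>"
#   bad = "<code>1/0</code><interpreter>ZeroDivisionError: division by zero</interpreter>"
#   code_execution_reward([[{"content": ok + bad}]])  # -> [0.5]
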
def len_reward(completions, **kwargs):
    """Reward shorter completions to encourage efficiency."""
    completion_contents = [completion[0]["content"] for completion in completions]
    lengths = [len(content) for content in completion_contents]
    min_len, max_len = min(lengths), max(lengths)

    # If all completions have the same length, return neutral rewards
    if min_len == max_len:
        return [0.0] * len(completions)

    # Normalize lengths to [0, 1] and invert (shorter = higher reward)
    normalized_lengths = [(length - min_len) / (max_len - min_len) for length in lengths]
    rewards = [1.0 - norm_length for norm_length in normalized_lengths]

    # Scale to a smaller range so length stays a secondary consideration
    return [0.2 * reward for reward in rewards]

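# Example (illustrative): lengths of 100/200/300 characters normalize to
# 0.0/0.5/1.0, invert to 1.0/0.5/0.0, and scale to rewards of 0.2/0.1/0.0.
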
def code_ratio_reward(completions, **kwargs):
    """Reward an appropriate code-to-text ratio."""
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content in completion_contents:
        # Extract all code blocks and measure their share of the completion
        code_blocks = re.findall(r'<code>(.*?)</code>', content, re.DOTALL)
        total_code_length = sum(len(code) for code in code_blocks)
        total_length = len(content)

        if total_length == 0:
            rewards.append(0.0)
            continue

        code_ratio = total_code_length / total_length

        # Reward an optimal ratio band (0.2 to 0.4), with partial credit nearby
        if 0.2 <= code_ratio <= 0.4:
            rewards.append(0.3)  # full reward
        elif 0.1 <= code_ratio < 0.2 or 0.4 < code_ratio <= 0.5:
            rewards.append(0.2)  # partial reward
        elif 0.05 <= code_ratio < 0.1 or 0.5 < code_ratio <= 0.6:
            rewards.append(0.1)  # minimal reward
        else:
            rewards.append(0.0)  # no reward

    return rewards

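# Example (illustrative): 30 characters of code in a 100-character completion
# gives code_ratio = 0.3, inside the optimal band, so the reward is 0.3.
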
def code_timing_reward(completions, **kwargs):
    """Reward invoking code at appropriate points in the reasoning process."""
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content in completion_contents:
        # Relative position of the first code block within the completion
        first_code_pos = content.find('<code>')
        if first_code_pos == -1:
            rewards.append(0.0)
            continue

        relative_pos = first_code_pos / len(content)

        # Reward early-to-middle code invocation (10% to 40% of the way through)
        if 0.1 <= relative_pos <= 0.4:
            rewards.append(0.3)
        elif 0.05 <= relative_pos < 0.1 or 0.4 < relative_pos <= 0.5:
            rewards.append(0.2)
        elif relative_pos < 0.05 or 0.5 < relative_pos <= 0.7:
            rewards.append(0.1)
        else:
            rewards.append(0.0)

    return rewards

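# Example (illustrative): if the first <code> tag opens at character 200 of a
# 1000-character completion, relative_pos = 0.2, which lands in the 0.1-0.4
# band and earns the full 0.3 reward.
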
def get_reward_funcs(script_args) -> list[Callable]:
    """Look up and return the reward functions named in script_args.reward_funcs."""
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "code_execution": code_execution_reward,
        "length": len_reward,
        "code_ratio": code_ratio_reward,
        "code_timing": code_timing_reward,
    }
    return [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
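
# These reward functions are meant to be plugged into a TRL-style trainer. A
# sketch (assumes a recent `trl` with GRPOTrainer, and a dataset whose
# `correct_answers` column is forwarded to the reward functions as a kwarg):
#
#   from trl import GRPOTrainer
#   trainer = GRPOTrainer(
#       model=model,
#       reward_funcs=get_reward_funcs(script_args),
#       args=training_args,
#       train_dataset=train_dataset,
#   )

if __name__ == "__main__":
    # Minimal smoke test (illustrative only)
    sample = [[{"role": "assistant", "content": (
        "First I compute the sum. <code>print(2 + 2)</code> "
        "<interpreter>4</interpreter> <answer>\\boxed{4}</answer>"
    )}]]
    print(format_reward(sample))                           # [1.0]
    print(accuracy_reward(sample, correct_answers=["4"]))  # [1.0]
    print(code_execution_reward(sample))                   # [1.0]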