|
import re |
|
from typing import List, Dict, Optional, Callable |
|
import numpy as np |
|
|
|
def format_reward(completions, **kwargs):
    """Score each completion on whether it follows the expected tag layout.

    A completion earns 1.0 when its text contains a ``<code>...</code>``
    section followed, somewhere later, by an ``<answer>...</answer>``
    section; otherwise it earns 0.0.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float (1.0 or 0.0) per completion, in input order.
    """
    layout = re.compile(
        r".*?<code>.*?</code>.*?<answer>.*?</answer>.*?",
        re.DOTALL | re.MULTILINE,
    )

    scores = []
    for completion in completions:
        text = completion[0]["content"]
        scores.append(1.0 if layout.search(text) else 0.0)
    return scores
|
|
|
def accuracy_reward(
    completions: list[list[dict[str, str]]],
    correct_answers: list[str],
    **kwargs,
) -> list[Optional[float]]:
    """Score each completion on whether its boxed answer equals the ground truth.

    The answer is taken from the first ``<answer>\\boxed{...}</answer>``
    occurrence in the completion text. Both the extracted answer and a string
    ground truth are stripped of surrounding whitespace before the exact
    comparison, so a stray space or trailing newline in the reference does not
    turn a correct answer into a miss.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        correct_answers: Ground-truth answers, paired positionally with
            ``completions`` (pairing stops at the shorter of the two).
        **kwargs: Accepted and ignored.

    Returns:
        A list with 1.0 for a matching answer and 0.0 for a mismatch or a
        completion without a boxed answer.
    """
    contents = [completion[0]["content"] for completion in completions]
    # Compiled once instead of being looked up on every loop iteration.
    boxed_answer = re.compile(r'<answer>\\boxed{(.*?)}</answer>', re.DOTALL)

    rewards = []
    for content, correct_answer in zip(contents, correct_answers):
        answer_match = boxed_answer.search(content)
        if not answer_match:
            rewards.append(0.0)
            continue

        extracted_answer = answer_match.group(1).strip()
        # Normalize the reference the same way as the extracted answer;
        # non-string references are compared untouched.
        if isinstance(correct_answer, str):
            expected_answer = correct_answer.strip()
        else:
            expected_answer = correct_answer

        rewards.append(1.0 if extracted_answer == expected_answer else 0.0)

    return rewards
|
|
|
def code_execution_reward(completions, **kwargs):
    """Score each completion by the fraction of its code runs that succeeded.

    A "code run" is a ``<code>...</code>`` section immediately followed
    (whitespace allowed) by an ``<interpreter>...</interpreter>`` section.
    A run counts as failed when its interpreter output contains ``Error``,
    ``Exception`` or ``Traceback``.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float per completion: 0.0 when the completion has no
        code run, otherwise ``1.0 - failed_runs / total_runs``.
    """
    completion_contents = [completion[0]["content"] for completion in completions]

    run_output = re.compile(
        r'<code>.*?</code>\s*<interpreter>(.*?)</interpreter>', re.DOTALL
    )
    # findall() yields only the text *inside* the interpreter tags, so the
    # error markers must be searched without the surrounding tags; patterns
    # that include the tags can never match the captured output.
    error_marker = re.compile(r'Error|Exception|Traceback')

    rewards = []
    for content in completion_contents:
        outputs = run_output.findall(content)
        if not outputs:
            rewards.append(0.0)
            continue

        error_count = sum(1 for output in outputs if error_marker.search(output))
        rewards.append(1.0 - (error_count / len(outputs)))

    return rewards
|
|
|
def len_reward(completions, **kwargs):
    """Reward shorter completions relative to the others in the same batch.

    Lengths are min-max normalized across the batch; the shortest completion
    earns 0.2, the longest earns 0.0, and the rest fall linearly in between.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float in ``[0.0, 0.2]`` per completion. An empty
        batch yields an empty list, and a batch whose completions all have
        the same length yields all zeros.
    """
    lengths = [len(completion[0]["content"]) for completion in completions]

    # min()/max() raise ValueError on an empty sequence; an empty batch
    # simply has no rewards.
    if not lengths:
        return []

    # Computed once rather than once per element inside the comprehension.
    shortest = min(lengths)
    longest = max(lengths)

    # No spread means no completion is preferable (and avoids dividing by 0).
    if shortest == longest:
        return [0.0] * len(completions)

    spread = longest - shortest
    return [0.2 * (1.0 - (length - shortest) / spread) for length in lengths]
|
|
|
def code_ratio_reward(completions, **kwargs):
    """Score each completion on the share of its text that sits inside code tags.

    The ratio is the summed length of everything between ``<code>`` and
    ``</code>`` divided by the total length of the completion text. Rewards
    by ratio band:

    * 0.2 to 0.4 (inclusive): 0.3
    * 0.1 up to 0.2, or above 0.4 up to 0.5: 0.2
    * 0.05 up to 0.1, or above 0.5 up to 0.6: 0.1
    * anything else, including empty text: 0.0

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float per completion, in input order.
    """
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        if len(text) == 0:
            scores.append(0.0)
            continue

        snippets = re.findall(r'<code>(.*?)</code>', text, re.DOTALL)
        ratio = sum(map(len, snippets)) / len(text)

        # Work from the outermost band inwards; what survives every check
        # lies in the ideal 0.2..0.4 range.
        if ratio < 0.05 or ratio > 0.6:
            score = 0.0
        elif ratio < 0.1 or ratio > 0.5:
            score = 0.1
        elif ratio < 0.2 or ratio > 0.4:
            score = 0.2
        else:
            score = 0.3
        scores.append(score)

    return scores
|
|
|
def code_timing_reward(completions, **kwargs):
    """Score each completion on where its first ``<code>`` tag appears.

    The position of the first ``<code>`` tag is expressed as a fraction of
    the completion's total length. Rewards by position band:

    * 0.1 to 0.4 (inclusive): 0.3
    * 0.05 up to 0.1, or above 0.4 up to 0.5: 0.2
    * below 0.05, or above 0.5 up to 0.7: 0.1
    * above 0.7, or no ``<code>`` tag at all: 0.0

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one float per completion, in input order.
    """

    def score_position(text):
        # Reward for a single completion text.
        tag_index = text.find('<code>')
        if tag_index == -1:
            return 0.0

        # The text contains the tag, so its length is never zero here.
        fraction = tag_index / len(text)
        if fraction > 0.7:
            return 0.0
        if fraction < 0.05 or fraction > 0.5:
            return 0.1
        if fraction < 0.1 or fraction > 0.4:
            return 0.2
        return 0.3

    texts = [completion[0]["content"] for completion in completions]
    return [score_position(text) for text in texts]
|
|
|
def get_reward_funcs(script_args) -> list[Callable]:
    """Resolve the reward function names in ``script_args`` to callables.

    Args:
        script_args: Object exposing ``reward_funcs``, an iterable of names.
            Recognized names are ``"accuracy"``, ``"format"``,
            ``"code_execution"``, ``"length"``, ``"code_ratio"`` and
            ``"code_timing"``.

    Returns:
        The reward functions matching the requested names, in request order.

    Raises:
        KeyError: If a requested name is not one of the recognized names.
    """
    available = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "code_execution": code_execution_reward,
        "length": len_reward,
        "code_ratio": code_ratio_reward,
        "code_timing": code_timing_reward,
    }

    selected = []
    for requested_name in script_args.reward_funcs:
        selected.append(available[requested_name])
    return selected