|
from typing import Any, Callable, Optional, Union |
|
from collections import defaultdict |
|
import re |
|
from trl.data_utils import maybe_apply_chat_template
from trl.extras.profiling import profiling_decorator
from trl.trainer.utils import nanmax, nanmin
|
|
|
import datasets |
|
import torch

import torch.nn.functional as F
|
import torch.utils.data |
|
import transformers |
|
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed |
|
from datasets import Dataset, IterableDataset |
|
from packaging import version |
|
from torch import nn |
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
from torch.utils.data import DataLoader, Sampler |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoModelForSequenceClassification, |
|
AutoTokenizer, |
|
GenerationConfig, |
|
PreTrainedModel, |
|
PreTrainedTokenizerBase, |
|
Trainer, |
|
TrainerCallback, |
|
is_wandb_available, |
|
PreTrainedTokenizer, |
|
) |
|
|
|
|
|
|
|
class ReToolTrainer(Trainer): |
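    """
    Trainer for ReTool-style reinforcement learning with an external code interpreter.

    Generation is multi-turn: the policy writes text until it emits either an EOS token or
    a closing ``</code>`` tag; emitted code blocks are executed and the interpreter output
    is appended between ``<interpreter>`` ... ``</interpreter>`` tags before generation
    resumes. Interpreter tokens are masked out of the loss, and completions are scored
    with a binary correctness reward that is used directly as the advantage.

    Illustrative usage (a minimal sketch; the model name and dataset columns are
    assumptions, not fixed by this class):

    ```python
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    trainer = ReToolTrainer(
        model=model,
        processing_class=tokenizer,
        args=transformers.TrainingArguments(output_dir="retool-out"),
        train_dataset=train_dataset,  # expects "prompt" and "answer" columns
    )
    trainer.train()
    ```
    """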
|
|
|
def __init__( |
|
self, |
|
model: Optional[PreTrainedModel] = None, |
|
processing_class: Optional[PreTrainedTokenizerBase] = None, |
|
args: Optional[transformers.TrainingArguments] = None, |
|
reward_funcs: Optional[list[Callable]] = None, |
|
train_dataset: Optional[Dataset] = None, |
|
eval_dataset: Optional[Dataset] = None, |
|
|
|
eos_id: Optional[int] = None, |
|
interpreter_id: Optional[list[int]] = None, |
|
code_id: Optional[list[int]] = None, |
|
max_turns: int = 10, |
|
max_completion_length: int = 1024, |
|
temperature: float = 0.7, |
|
top_p: float = 0.9, |
|
top_k: int = 50, |
|
min_p: Optional[float] = None, |
|
mask_truncated_completions: bool = True, |
|
**kwargs |
|
): |
|
|
|
super().__init__( |
|
model=model, |
|
tokenizer=processing_class, |
|
args=args, |
|
train_dataset=train_dataset, |
|
eval_dataset=eval_dataset, |
|
**kwargs |
|
) |
|
|
|
|
|
|
|
self.processing_class = processing_class or self.tokenizer |
|
|
|
|
|
self.reward_funcs = reward_funcs or [self._binary_reward_function] |
|
|
|
|
|
|
|
|
|
|
|
|
self.eos_id = eos_id or self.processing_class.eos_token_id |
|
self.interpreter_id = interpreter_id or self._get_interpreter_token_ids() |
|
self.code_id = code_id or self._get_code_token_ids() |
|
self.max_turns = max_turns |
|
self.max_completion_length = max_completion_length |
|
self.temperature = temperature |
|
self.top_p = top_p |
|
self.top_k = top_k |
|
self.min_p = min_p |
|
self.mask_truncated_completions = mask_truncated_completions |
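
        # Attributes assumed by the generation and loss code below. The numeric defaults
        # are assumptions (GRPO/PPO-style clip range and KL weight); `ref_model` must be
        # assigned a frozen reference model before `_compute_loss` is called.
        self.max_prompt_length = None
        self.beta = 0.0
        self.epsilon_low = 0.2
        self.epsilon_high = 0.2
        self.ref_model = None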
|
|
|
|
|
self.reward_func_names = ["binary_correctness"] |
|
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} |
|
self._textual_logs = { |
|
"prompt": [], |
|
"completion": [], |
|
"rewards": {"binary_correctness": []} |
|
} |
|
|
|
|
|
        # Per-turn generation settings; generation also stops at the closing </code>
        # token so emitted code blocks can be intercepted and executed.
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
|
do_sample=True, |
|
pad_token_id=self.processing_class.pad_token_id, |
|
bos_token_id=self.processing_class.bos_token_id, |
|
eos_token_id=[self.eos_id, self.code_id[1]], |
|
temperature=self.temperature, |
|
top_p=self.top_p, |
|
top_k=self.top_k, |
|
min_p=self.min_p, |
|
return_dict_in_generate=True, |
|
use_cache=True, |
|
) |
|
|
|
|
|
def _get_interpreter_token_ids(self) -> list[int]: |
|
"""Get token IDs for <interpreter> and </interpreter> tags.""" |
|
start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0] |
|
end_token = self.processing_class.encode("</interpreter>", add_special_tokens=False)[0] |
|
return [start_token, end_token] |
|
|
|
def _get_code_token_ids(self) -> list[int]: |
|
"""Get token IDs for <code> and </code> tags.""" |
|
start_token = self.processing_class.encode("<code>", add_special_tokens=False)[0] |
|
end_token = self.processing_class.encode("</code>", add_special_tokens=False)[0] |
|
return [start_token, end_token] |
|
|
|
def _binary_reward_function(self, prompts, completions, **kwargs) -> list[float]: |
|
"""Default binary reward function for mathematical correctness.""" |
|
rewards = [] |
|
ground_truths = kwargs.get('ground_truths', [None] * len(completions)) |
|
|
|
for completion, ground_truth in zip(completions, ground_truths): |
|
if self._is_correct_answer(completion, ground_truth): |
|
rewards.append(1.0) |
|
else: |
|
rewards.append(-1.0) |
|
return rewards |
|
|
|
def _execute_code(self, code_block: str) -> str: |
|
""" |
|
Execute code in a sandbox environment. |
|
|
|
TODO: Implement actual code execution sandbox. |
|
For now, returns a placeholder. |
|
""" |
|
|
|
return f"Executed: {code_block[:50]}... -> Result: 42" |
|
|
|
|
|
def _check_equivalence(self, predicted, ground_truth): |
|
"""Simple equivalence check - you can make this more sophisticated later.""" |
|
|
|
return str(predicted).strip() == str(ground_truth).strip() |
|
|
|
def _is_correct_answer(self, completion_text, ground_truth): |
|
        """Check whether the completion's ``\\boxed{}`` answer matches the ground truth."""
|
|
|
match = re.search(r'\\boxed\{([^}]+)\}', completion_text) |
|
if match: |
|
predicted = match.group(1) |
|
return self._check_equivalence(predicted, ground_truth) |
|
return False |
|
|
|
def _compute_rewards_and_advantages(self, completions_text, ground_truths, device): |
|
"""Simplified reward and advantage computation for ReTool.""" |
|
|
|
|
|
rewards = [] |
|
for completion_text, ground_truth in zip(completions_text, ground_truths): |
|
if self._is_correct_answer(completion_text, ground_truth): |
|
rewards.append(1.0) |
|
else: |
|
rewards.append(-1.0) |
|
|
|
|
|
advantages = torch.tensor(rewards, dtype=torch.float32, device=device) |
|
|
|
return advantages |
|
|
|
|
|
def _retool_generate_with_interpreter( |
|
self, |
|
prompt_ids_batch: torch.Tensor, |
|
attention_mask_batch: torch.Tensor, |
|
|
|
eos_id: int, |
|
interpreter_id: list[int], |
|
code_id: list[int], |
|
max_turns: int = 10 |
|
) -> tuple[torch.LongTensor, list[list[tuple[int, int]]]]: |
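        """
        Generate completions turn by turn, executing emitted <code> blocks and feeding the
        interpreter output back into the context.

        Returns padded completion ids and, for each sequence, the (start, end) index pairs
        (end inclusive) of interpreter-inserted spans within the completion.
        """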
|
|
|
batch_size = prompt_ids_batch.size(0) |
|
batch_completion = [] |
|
batch_interpreter_positions = [] |
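
        # Each sequence is processed independently; the per-turn KV cache is reused so
        # only newly appended tokens are fed to subsequent generate() calls.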
|
|
|
for i in range(batch_size): |
|
|
|
current_input_id = prompt_ids_batch[i:i+1] |
|
current_attention_mask = attention_mask_batch[i:i+1] |
|
current_kv = None |
|
|
|
|
|
cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device) |
|
interpreter_positions = [] |
|
|
|
for turn_idx in range(max_turns): |
|
|
|
model_outputs = self.model.generate( |
|
input_ids=current_input_id, |
|
attention_mask=current_attention_mask, |
|
eos_token_id=[eos_id, code_id[1]], |
|
past_key_values=current_kv, |
|
generation_config=self.generation_config, |
|
|
|
) |
|
|
|
|
|
current_full_ids = model_outputs.sequences |
|
|
|
|
|
completion_id = current_full_ids[:, current_input_id.size(1):] |
|
|
|
|
|
cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1) |
|
|
|
|
|
|
|
|
|
current_attention_mask = torch.cat([ |
|
current_attention_mask, |
|
torch.ones_like(completion_id) |
|
], dim=1) |
|
|
|
current_kv = model_outputs.past_key_values |
|
|
|
last_token_id = current_full_ids[0, -1].item() |
|
|
|
if last_token_id == eos_id or turn_idx == max_turns - 1: |
|
batch_completion.append(cumulative_completion_ids.squeeze(0)) |
|
batch_interpreter_positions.append(interpreter_positions) |
|
break |
|
|
|
if last_token_id == code_id[1]: |
|
|
|
|
|
full_text = self.processing_class.decode(current_full_ids[0]) |
|
                    code_matches = re.findall(r'<code>(.*?)</code>', full_text, re.DOTALL)
                    if code_matches:
                        # Execute the most recently emitted code block.
                        code_block = code_matches[-1]
                        interpreter_text = self._execute_code(code_block)
                    else:
                        interpreter_text = "Error: No code found"
|
|
|
formatted_feedback_text = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}" |
|
|
|
interpreter_feedback_id = self.processing_class( |
|
formatted_feedback_text, |
|
return_tensors="pt", |
|
add_special_tokens=False |
|
).input_ids.to(current_full_ids.device) |
|
|
|
|
|
|
|
interpreter_start_idx = cumulative_completion_ids.size(1) |
|
cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1) |
|
interpreter_end_idx = cumulative_completion_ids.size(1) - 1 |
|
interpreter_positions.append((interpreter_start_idx, interpreter_end_idx)) |
|
|
|
|
|
current_attention_mask = torch.cat([ |
|
current_attention_mask, |
|
torch.ones_like(interpreter_feedback_id) |
|
], dim=1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
current_input_id = interpreter_feedback_id |
|
|
|
|
|
else: |
|
|
|
batch_completion.append(cumulative_completion_ids.squeeze(0)) |
|
batch_interpreter_positions.append(interpreter_positions) |
|
|
|
break |
|
else: |
|
batch_completion.append(cumulative_completion_ids.squeeze(0)) |
|
batch_interpreter_positions.append(interpreter_positions) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
padded_sequences = torch.nn.utils.rnn.pad_sequence(batch_completion, batch_first=True, padding_value=self.processing_class.pad_token_id) |
|
|
|
return padded_sequences, batch_interpreter_positions |
|
|
|
|
|
|
|
def _create_interpreter_mask( |
|
self, |
|
completion_ids: torch.Tensor, |
|
interpreter_positions: list[list[tuple[int, int]]] |
|
) -> torch.Tensor: |
|
""" |
|
Create interpreter mask from positions. |
|
|
|
Args: |
|
completion_ids: Tensor of shape (batch_size, seq_length) |
|
interpreter_positions: List[List[Tuple[start_idx, end_idx]]] |
|
- Indices are relative to completion_ids |
|
- start_idx: inclusive, end_idx: INCLUSIVE (unlike typical Python slicing) |
|
|
|
Returns: |
|
interpreter_mask: Tensor of shape (batch_size, seq_length) |
|
1 = model-generated token, 0 = interpreter token |
|
""" |
|
batch_size, seq_length = completion_ids.shape |
|
|
|
|
|
interpreter_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=completion_ids.device) |
|
|
|
|
|
for batch_idx, positions_in_sequence in enumerate(interpreter_positions): |
|
|
|
for start_idx, end_idx in positions_in_sequence: |
|
|
|
start_idx = max(0, min(start_idx, seq_length - 1)) |
|
end_idx = max(0, min(end_idx, seq_length - 1)) |
|
|
|
|
|
if start_idx <= end_idx: |
|
interpreter_mask[batch_idx, start_idx:end_idx + 1] = 0 |
|
|
|
return interpreter_mask |
|
|
|
|
|
def _generate_and_score_completions( |
|
self, inputs: list[dict[str, Union[torch.Tensor, Any]]] |
|
) -> dict[str, Union[torch.Tensor, Any]]: |
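        """Run multi-turn generation with interpreter feedback, then build masks, compute rewards, and log metrics."""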
|
|
|
device = self.accelerator.device |
|
mode = "train" if self.model.training else "eval" |
|
|
|
prompts = [x["prompt"] for x in inputs] |
|
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] |
|
prompt_inputs = self.processing_class( |
|
text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False |
|
) |
|
prompt_inputs = super()._prepare_inputs(prompt_inputs) |
|
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] |
|
|
|
if self.max_prompt_length is not None: |
|
prompt_ids = prompt_ids[:, -self.max_prompt_length :] |
|
prompt_mask = prompt_mask[:, -self.max_prompt_length :] |
|
|
|
|
|
|
|
        completion_ids, interpreter_positions = self._retool_generate_with_interpreter(
            prompt_ids,
            attention_mask_batch=prompt_mask,
            eos_id=self.eos_id,
            interpreter_id=self.interpreter_id,
            code_id=self.code_id,
            max_turns=self.max_turns,
        )
|
|
|
|
|
|
|
        # Mask out everything after the first EOS token in each completion.
        is_eos = completion_ids == self.processing_class.eos_token_id
|
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) |
|
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] |
|
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) |
|
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() |
|
|
|
|
|
|
|
        # Zero out interpreter-inserted tokens so they do not contribute to the policy loss.
        interpreter_mask = self._create_interpreter_mask(completion_ids, interpreter_positions)
|
|
|
|
|
|
|
if self.mask_truncated_completions: |
|
truncated_completions = ~is_eos.any(dim=1) |
|
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() |
|
|
|
|
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) |
|
|
|
|
|
|
|
|
|
|
|
ground_truths = [x.get("answer") for x in inputs] |
|
|
|
|
|
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) |
|
|
|
|
|
advantages = self._compute_rewards_and_advantages( |
|
completions_text, |
|
ground_truths, |
|
device=device |
|
) |
|
|
|
|
|
|
|
if mode == "train": |
|
self.state.num_input_tokens_seen += attention_mask.sum().item() |
|
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] |
|
|
|
|
|
completion_lengths = completion_mask.sum(1) |
|
self._metrics[mode]["completions/mean_length"].append(completion_lengths.float().mean().item()) |
|
self._metrics[mode]["completions/min_length"].append(completion_lengths.float().min().item()) |
|
self._metrics[mode]["completions/max_length"].append(completion_lengths.float().max().item()) |
|
|
|
|
|
terminated_with_eos = is_eos.any(dim=1) |
|
term_completion_lengths = completion_lengths[terminated_with_eos] |
|
clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths) |
|
self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) |
|
|
|
if len(term_completion_lengths) == 0: |
|
term_completion_lengths = torch.zeros(1, device=device) |
|
|
|
self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) |
|
|
|
|
|
advantages_tensor = advantages |
|
self._metrics[mode]["rewards/binary_correctness/mean"].append(advantages_tensor.mean().item()) |
|
self._metrics[mode]["rewards/binary_correctness/std"].append(advantages_tensor.std().item()) |
|
|
|
|
|
|
|
self._textual_logs["prompt"].extend(prompts_text) |
|
self._textual_logs["completion"].extend(completions_text) |
|
self._textual_logs["rewards"]["binary_correctness"].extend(advantages.tolist()) |
|
|
|
return { |
|
"prompt_ids": prompt_ids, |
|
"prompt_mask": prompt_mask, |
|
"completion_ids": completion_ids, |
|
"completion_mask": completion_mask, |
|
"interpreter_mask": interpreter_mask, |
|
"advantages": advantages |
|
} |
|
|
|
|
|
|
|
@profiling_decorator |
|
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor: |
|
batch_size = batch_size or input_ids.size(0) |
|
all_logps = [] |
|
for i in range(0, input_ids.size(0), batch_size): |
|
input_ids_batch = input_ids[i : i + batch_size] |
|
attention_mask_batch = attention_mask[i : i + batch_size] |
|
|
|
|
|
logits = model( |
|
input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1 |
|
).logits |
|
logits = logits[:, :-1, :] |
|
input_ids_batch = input_ids_batch[:, -logits_to_keep:] |
|
|
|
|
|
logits = logits[:, -logits_to_keep:] |
|
|
|
|
|
logits = logits / self.temperature |
|
            logps = self.selective_log_softmax(logits, input_ids_batch)
|
all_logps.append(logps) |
|
return torch.cat(all_logps, dim=0) |
|
|
|
|
|
@staticmethod |
|
def selective_log_softmax(logits, index): |
|
""" |
|
A memory-efficient implementation of the common `log_softmax -> gather` operation. |
|
|
|
This function is equivalent to the following naive implementation: |
|
```python |
|
logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) |
|
``` |
|
|
|
Args: |
|
logits (`torch.Tensor`): |
|
Logits tensor of shape `(..., num_classes)`. |
|
index (`torch.Tensor`): |
|
Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. |
|
|
|
Returns: |
|
`torch.Tensor`: |
|
Gathered log probabilities with the same shape as `index`. |
|
""" |
|
if logits.dtype in [torch.float32, torch.float64]: |
|
selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) |
|
|
|
logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) |
|
per_token_logps = selected_logits - logsumexp_values |
|
else: |
|
|
|
per_token_logps = [] |
|
for row_logits, row_labels in zip(logits, index): |
|
row_logps = F.log_softmax(row_logits, dim=-1) |
|
row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) |
|
per_token_logps.append(row_per_token_logps) |
|
per_token_logps = torch.stack(per_token_logps) |
|
return per_token_logps |
|
|
|
|
|
def _compute_loss(self, model, inputs): |
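        """Clipped policy-gradient loss over completion tokens, with interpreter tokens masked out of the objective."""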
|
|
|
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] |
|
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] |
|
|
|
|
|
interpreter_mask = inputs["interpreter_mask"] |
|
final_mask = interpreter_mask * completion_mask |
|
|
|
input_ids = torch.cat([prompt_ids, completion_ids], dim=1) |
|
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) |
|
logits_to_keep = completion_ids.size(1) |
|
|
|
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) |
|
|
|
with torch.no_grad(): |
|
ref_per_token_logps = self._get_per_token_logps( |
|
self.ref_model, input_ids, attention_mask, logits_to_keep |
|
) |
|
|
|
if self.beta != 0.0: |
|
per_token_kl = ( |
|
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 |
|
) |
|
|
|
|
|
advantages = inputs["advantages"] |
|
|
|
        # Simplification: the reference-model log-probs stand in for the old-policy
        # log-probs in the PPO-style ratio.
        old_per_token_logps = ref_per_token_logps
|
coef_1 = torch.exp(per_token_logps - old_per_token_logps) |
|
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) |
|
|
|
per_token_loss1 = coef_1 * advantages.unsqueeze(1) |
|
per_token_loss2 = coef_2 * advantages.unsqueeze(1) |
|
per_token_loss = -torch.min(per_token_loss1, per_token_loss2) |
|
if self.beta != 0.0: |
|
per_token_loss = per_token_loss + self.beta * per_token_kl |
|
|
|
|
|
|
|
masked_loss = per_token_loss * final_mask |
|
total_valid_tokens = final_mask.sum() + 1e-8 |
|
loss = masked_loss.sum() / total_valid_tokens |
|
|
|
""" --- """ |
|
|
|
|
|
mode = "train" if self.model.training else "eval" |
|
|
|
if self.beta != 0.0: |
|
mean_kl = (per_token_kl * final_mask).sum() / final_mask.sum() |
|
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item()) |
|
|
|
|
|
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) |
|
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) |
|
is_region_clipped = is_low_clipped | is_high_clipped |
|
|
|
low_clip = (is_low_clipped * final_mask).sum() / final_mask.sum() |
|
high_clip = (is_high_clipped * final_mask).sum() / final_mask.sum() |
|
clip_ratio = (is_region_clipped * final_mask).sum() / final_mask.sum() |
|
|
|
gathered_low_clip = self.accelerator.gather_for_metrics(low_clip) |
|
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) |
|
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) |
|
gathered_high_clip = self.accelerator.gather_for_metrics(high_clip) |
|
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) |
|
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) |
|
gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio) |
|
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) |
|
return loss |