# ReTool-Implementation / src/retool_trainer.py
# Uploaded by bird-of-paradise ("Upload 5 files"), commit cb481ca (verified), 27.5 kB.
import re
from collections import defaultdict
from collections import deque
from typing import Any, Callable, Optional, Union

import datasets
import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
    PreTrainedTokenizer,
)

# NOTE(review): not used by ReToolTrainer; presumably meant to be TRL's
# `profiling_decorator` (trl.extras.profiling) — confirm this module exists.
import profiling_decorator
class ReToolTrainer(Trainer):
    """Multi-turn, tool-augmented RL trainer in the style of ReTool.

    Each rollout alternates between model generation and code execution:
    whenever a generated segment ends with the ``</code>`` token, the most
    recent ``<code>...</code>`` block is passed to ``_execute_code`` and its
    result is appended to the completion wrapped in the interpreter tags.
    Interpreter tokens stay visible to the model (attention) but are excluded
    from the policy loss through ``interpreter_mask``.

    Training is wired into ``transformers.Trainer`` through the overrides of
    ``_prepare_inputs`` (rollout + scoring) and ``compute_loss`` (clipped
    surrogate loss in ``_compute_loss``).

    NOTE(review): dataset rows are expected to carry a ``"prompt"`` column
    (plain string or list of chat messages) and an ``"answer"`` column with
    the ground truth read by the default reward — confirm against the data.
    """

    # Upper bound on how many prompts/completions/rewards are kept in
    # `_textual_logs` for debugging; older entries are discarded.
    _MAX_TEXTUAL_LOGS = 1024

    def __init__(
        self,
        model: Optional[PreTrainedModel] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        args: Optional[transformers.TrainingArguments] = None,
        reward_funcs: Optional[list[Callable]] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        # ReTool specific parameters
        eos_id: Optional[int] = None,
        interpreter_id: Optional[list[int]] = None,
        code_id: Optional[list[int]] = None,
        max_turns: int = 10,
        max_completion_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        min_p: Optional[float] = None,
        mask_truncated_completions: bool = True,
        max_prompt_length: Optional[int] = None,
        beta: float = 0.0,
        epsilon_low: float = 0.2,
        epsilon_high: Optional[float] = None,
        ref_model: Optional[PreTrainedModel] = None,
        **kwargs,
    ):
        """Build the trainer.

        Args:
            model: Policy model passed to ``Trainer``.
            processing_class: Tokenizer used for prompts, tool feedback and decoding.
            args: ``TrainingArguments`` passed to ``Trainer``.
            reward_funcs: Callables ``f(prompts=..., completions=..., ground_truths=...)``
                returning one float per completion. Their outputs are summed.
                Defaults to the built-in binary correctness reward.
            train_dataset / eval_dataset: Datasets passed to ``Trainer``.
            eos_id: End-of-sequence token id (defaults to the tokenizer's).
            interpreter_id: ``[start_id, end_id]`` of the interpreter tags.
            code_id: ``[start_id, end_id]`` of the code tags. ``code_id[1]`` is a
                generation stop token.
            max_turns: Maximum number of generate calls per rollout.
            max_completion_length: Total token budget of a completion, counting
                model tokens and interpreter feedback.
            temperature / top_p / top_k / min_p: Sampling parameters.
            mask_truncated_completions: Drop completions without EOS from the loss.
            max_prompt_length: Keep only the last N prompt tokens (None = keep all).
            beta: KL coefficient against ``ref_model`` (0 disables the KL term).
            epsilon_low / epsilon_high: Lower/upper PPO clip ranges. ``epsilon_high``
                defaults to ``epsilon_low``.
            ref_model: Frozen reference model. Required when ``beta != 0``.
            **kwargs: Forwarded to ``transformers.Trainer``.

        Raises:
            ValueError: If no tokenizer is available, if ``beta != 0`` without a
                ``ref_model``, or if a default tag is not a single token.
        """
        # Raw dataset rows (strings / chat messages) must reach
        # `_generate_and_score_completions` untouched. The default collator
        # would try to tokenize/pad them.
        kwargs.setdefault("data_collator", self._identity_collator)

        # transformers 4.46 renamed Trainer's `tokenizer` argument to
        # `processing_class`; the old name is deprecated there.
        if version.parse(transformers.__version__) >= version.parse("4.46.0"):
            kwargs["processing_class"] = processing_class
        else:
            kwargs["tokenizer"] = processing_class
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )

        if processing_class is None:
            processing_class = getattr(self, "processing_class", None) or getattr(self, "tokenizer", None)
        if processing_class is None:
            raise ValueError("ReToolTrainer requires a tokenizer passed as `processing_class`.")
        self.processing_class = processing_class

        # `_compute_loss` already averages over valid tokens. Let Trainer apply
        # the gradient-accumulation scaling instead of expecting
        # `num_items_in_batch`-normalized losses.
        self.model_accepts_loss_kwargs = False

        # Reward functions (summed into the advantage).
        if reward_funcs:
            self.reward_funcs = list(reward_funcs)
            self.reward_func_names = [
                getattr(func, "__name__", f"reward_{idx}") for idx, func in enumerate(self.reward_funcs)
            ]
        else:
            self.reward_funcs = [self._binary_reward_function]
            self.reward_func_names = ["binary_correctness"]

        # ReTool specific attributes. `is not None` keeps a legitimate eos id of 0.
        self.eos_id = eos_id if eos_id is not None else self.processing_class.eos_token_id
        self.interpreter_id = interpreter_id or self._get_interpreter_token_ids()
        self.code_id = code_id or self._get_code_token_ids()
        self.max_turns = max_turns
        self.max_completion_length = max_completion_length
        self.max_prompt_length = max_prompt_length
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.mask_truncated_completions = mask_truncated_completions

        # Policy-loss hyper-parameters.
        self.beta = beta
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high if epsilon_high is not None else epsilon_low
        if self.beta != 0.0 and ref_model is None:
            raise ValueError("`beta != 0` adds a KL penalty against a reference model; pass `ref_model`.")
        self.ref_model = ref_model
        if self.ref_model is not None:
            # Frozen, inference-only copy on this process's device.
            # NOTE(review): not sharded/wrapped for FSDP or DeepSpeed.
            self.ref_model.requires_grad_(False)
            self.ref_model.eval()
            self.ref_model.to(self.accelerator.device)

        # ReTool specific logging.
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._textual_logs = {
            "prompt": deque(maxlen=self._MAX_TEXTUAL_LOGS),
            "completion": deque(maxlen=self._MAX_TEXTUAL_LOGS),
            "rewards": {name: deque(maxlen=self._MAX_TEXTUAL_LOGS) for name in self.reward_func_names},
        }

        # Generation configuration for ReTool. `max_new_tokens` is only an upper
        # bound: each turn is further capped at the remaining
        # `max_completion_length` budget in `_retool_generate_with_interpreter`.
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
            bos_token_id=self.processing_class.bos_token_id,
            eos_token_id=[self.eos_id, self.code_id[1]],  # Stop on EOS or </code>
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            return_dict_in_generate=True,
            use_cache=True,
        )

    # ------------------------------------------------------------------ #
    # Trainer integration
    # ------------------------------------------------------------------ #
    @staticmethod
    def _identity_collator(features: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return the batch of raw dataset rows unchanged (no tokenization or padding)."""
        return features

    def _set_signature_columns_if_needed(self):
        # Keep the columns read by `_generate_and_score_completions`. By default
        # Trainer keeps only `model.forward` argument names, which would drop
        # them when `args.remove_unused_columns` is True.
        if self._signature_columns is None:
            self._signature_columns = ["prompt", "answer"]

    def _prepare_inputs(self, inputs):
        """Turn a raw batch (list of rows) into rollouts + advantages.

        Already-tensorized inputs fall back to Trainer's device placement.
        """
        if isinstance(inputs, list):
            return self._generate_and_score_completions(inputs)
        return super()._prepare_inputs(inputs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Trainer entry point; delegates to `_compute_loss` on prepared rollouts."""
        if return_outputs:
            raise ValueError("ReToolTrainer does not support returning outputs.")
        return self._compute_loss(model, inputs)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        """Evaluation step: roll out, then report the policy loss only."""
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss = self.compute_loss(model, inputs)
            loss = loss.mean().detach()
        return loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge the averaged ReTool metrics of the current mode into `logs`, then reset them."""
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(values) / len(values) for key, values in self._metrics[mode].items() if values}
        if mode == "eval":
            metrics = {f"eval_{key}": value for key, value in metrics.items()}
        logs = {**logs, **metrics}
        # `start_time` was added to Trainer.log in transformers 4.47.
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            super().log(logs, start_time)
        else:
            super().log(logs)
        self._metrics[mode].clear()

    # ------------------------------------------------------------------ #
    # Tags, prompts, tools and rewards
    # ------------------------------------------------------------------ #
    def _get_tag_token_ids(self, start_tag: str, end_tag: str) -> list[int]:
        """Return ``[start_id, end_id]`` for a pair of tags that must each be one token.

        The ids are used as generation stop tokens and decoded back into tag
        text, so a tag that splits into several tokens cannot be represented.

        Raises:
            ValueError: If either tag does not encode to exactly one token.
        """
        token_ids = []
        for tag in (start_tag, end_tag):
            ids = self.processing_class.encode(tag, add_special_tokens=False)
            if len(ids) != 1:
                raise ValueError(
                    f"Tag {tag!r} is encoded as {len(ids)} tokens {ids}; ReToolTrainer needs a single token. "
                    "Add it to the tokenizer as a special token (and resize the model embeddings) "
                    "or pass the token ids explicitly."
                )
            token_ids.append(ids[0])
        return token_ids

    def _get_interpreter_token_ids(self) -> list[int]:
        """Get token IDs for <interpreter> and </interpreter> tags."""
        return self._get_tag_token_ids("<interpreter>", "</interpreter>")

    def _get_code_token_ids(self) -> list[int]:
        """Get token IDs for <code> and </code> tags."""
        return self._get_tag_token_ids("<code>", "</code>")

    def _render_prompt(self, prompt: Union[str, list[dict[str, str]]]) -> str:
        """Turn a dataset prompt into text ready for tokenization.

        Plain strings are returned unchanged. A list of chat messages is rendered
        with the tokenizer's chat template plus a generation prompt. This is a
        minimal stand-in for TRL's `maybe_apply_chat_template`.
        """
        if isinstance(prompt, str):
            return prompt
        return self.processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

    def _execute_code(self, code_block: str) -> str:
        """
        Execute code in a sandbox environment.
        TODO: Implement actual code execution sandbox.
        For now, returns a placeholder.

        NOTE(review): `code_block` is model-generated text. A real
        implementation must run it in an isolated sandbox, never via
        `exec`/`eval` in this process.
        """
        # Placeholder implementation
        return f"Executed: {code_block[:50]}... -> Result: 42"

    def _binary_reward_function(self, prompts, completions, **kwargs) -> list[float]:
        """Default binary reward: +1.0 for a correct boxed answer, -1.0 otherwise.

        Ground truths are read from the ``ground_truths`` keyword (None if absent).
        """
        ground_truths = kwargs.get("ground_truths", [None] * len(completions))
        return [
            1.0 if self._is_correct_answer(completion, ground_truth) else -1.0
            for completion, ground_truth in zip(completions, ground_truths)
        ]

    def _check_equivalence(self, predicted, ground_truth):
        """Simple equivalence check - you can make this more sophisticated later."""
        # Simple string comparison for now
        return str(predicted).strip() == str(ground_truth).strip()

    @staticmethod
    def _extract_boxed_answer(text: str) -> Optional[str]:
        """Return the content of the last complete, non-empty ``\\boxed{...}`` in `text`.

        Braces are matched by depth, so nested answers such as
        ``\\boxed{\\frac{1}{2}}`` are returned whole. Returns None if no
        complete box is found.
        """
        marker = "\\boxed{"
        answer = None
        search_from = 0
        while True:
            start = text.find(marker, search_from)
            if start == -1:
                return answer
            content_start = start + len(marker)
            depth = 1
            pos = content_start
            while pos < len(text) and depth > 0:
                if text[pos] == "{":
                    depth += 1
                elif text[pos] == "}":
                    depth -= 1
                pos += 1
            if depth == 0 and pos - 1 > content_start:
                answer = text[content_start : pos - 1]
            search_from = content_start

    def _is_correct_answer(self, completion_text, ground_truth):
        """True iff the last boxed answer in `completion_text` matches `ground_truth`."""
        if ground_truth is None:
            return False
        predicted = self._extract_boxed_answer(completion_text)
        if predicted is None:
            return False
        return self._check_equivalence(predicted, ground_truth)

    def _compute_rewards_per_func(self, prompts, completions_text, ground_truths, device) -> torch.Tensor:
        """Evaluate every reward function; returns a float32 tensor of shape (B, num_reward_funcs)."""
        if prompts is None:
            prompts = [None] * len(completions_text)
        columns = []
        for reward_func in self.reward_funcs:
            rewards = reward_func(prompts=prompts, completions=completions_text, ground_truths=ground_truths)
            columns.append(torch.tensor([float(r) for r in rewards], dtype=torch.float32, device=device))
        return torch.stack(columns, dim=1)

    def _compute_rewards_and_advantages(self, completions_text, ground_truths, device, prompts=None):
        """Simplified reward and advantage computation for ReTool.

        Advantages are the summed rewards of all reward functions (no group
        normalization for now).
        """
        return self._compute_rewards_per_func(prompts, completions_text, ground_truths, device).sum(dim=1)

    # ------------------------------------------------------------------ #
    # Rollouts
    # ------------------------------------------------------------------ #
    def _retool_generate_with_interpreter(
        self,
        prompt_ids_batch: torch.Tensor,
        attention_mask_batch: torch.Tensor,
        eos_id: int,
        interpreter_id: list[int],
        code_id: list[int],
        max_turns: int = 10,
        return_lengths: bool = False,
    ) -> tuple:
        """Generate tool-augmented completions one prompt at a time.

        Args:
            prompt_ids_batch: (B, P) prompt token ids (left padding allowed).
            attention_mask_batch: (B, P) attention mask matching the prompts.
            eos_id: End-of-sequence id; finishes a rollout.
            interpreter_id: ``[start_id, end_id]`` decoded to wrap tool output.
            code_id: ``[start_id, end_id]``; ``code_id[1]`` triggers a tool call.
            max_turns: Maximum number of generate calls per prompt.
            return_lengths: Also return the unpadded completion lengths.

        Returns:
            ``(completion_ids, interpreter_positions)``, or, with `return_lengths`,
            ``(completion_ids, interpreter_positions, lengths)``.
            - `completion_ids` is a (B, C) tensor without the prompt, right-padded
              with the pad id (EOS if the tokenizer has none).
            - `interpreter_positions[i]` lists inclusive ``(start, end)`` spans of
              interpreter tokens in row i.
            - `lengths` is a (B,) long tensor.
            The total completion length is bounded by `self.max_completion_length`.
        """
        device = prompt_ids_batch.device
        tokenizer = self.processing_class
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id
        stop_ids = [eos_id, code_id[1]]
        interpreter_open = tokenizer.decode(interpreter_id[0])
        interpreter_close = tokenizer.decode(interpreter_id[1])

        batch_completion: list[torch.Tensor] = []
        batch_interpreter_positions: list[list[tuple[int, int]]] = []

        for i in range(prompt_ids_batch.size(0)):
            # `full_ids` = prompt + everything appended so far (model text and tool
            # feedback). HF `generate` expects the FULL sequence together with
            # `past_key_values`; it only runs the forward pass on uncached positions.
            full_ids = prompt_ids_batch[i : i + 1]
            full_mask = attention_mask_batch[i : i + 1]
            prompt_len = full_ids.size(1)
            past_key_values = None
            interpreter_positions: list[tuple[int, int]] = []

            for turn_idx in range(max_turns):
                remaining = self.max_completion_length - (full_ids.size(1) - prompt_len)
                if remaining <= 0:
                    break
                outputs = self.model.generate(
                    input_ids=full_ids,
                    attention_mask=full_mask,
                    eos_token_id=stop_ids,
                    past_key_values=past_key_values,
                    generation_config=self.generation_config,
                    max_new_tokens=remaining,
                )
                new_ids = outputs.sequences[:, full_ids.size(1) :]
                full_ids = outputs.sequences
                full_mask = torch.cat([full_mask, torch.ones_like(new_ids)], dim=1)
                past_key_values = outputs.past_key_values
                if new_ids.size(1) == 0:
                    break

                # Only a segment ending in </code> continues the rollout. EOS, the
                # token budget, or the last turn end it.
                if full_ids[0, -1].item() != code_id[1] or turn_idx == max_turns - 1:
                    break

                # Execute the most recent code block of the completion (never one
                # from the prompt, never an earlier, already-executed block).
                completion_text = tokenizer.decode(full_ids[0, prompt_len:])
                code_blocks = re.findall(r"<code>(.*?)</code>", completion_text, re.DOTALL)
                if code_blocks:
                    interpreter_text = self._execute_code(code_blocks[-1])
                else:
                    interpreter_text = "Error: No code found"
                feedback_ids = tokenizer(
                    f"{interpreter_open}{interpreter_text}{interpreter_close}",
                    return_tensors="pt",
                    add_special_tokens=False,
                ).input_ids.to(device)

                # Record positions relative to the completion before appending feedback.
                start_idx = full_ids.size(1) - prompt_len
                full_ids = torch.cat([full_ids, feedback_ids], dim=1)
                full_mask = torch.cat([full_mask, torch.ones_like(feedback_ids)], dim=1)
                interpreter_positions.append((start_idx, full_ids.size(1) - prompt_len - 1))

            batch_completion.append(full_ids[0, prompt_len:])
            batch_interpreter_positions.append(interpreter_positions)

        padded_sequences = torch.nn.utils.rnn.pad_sequence(
            batch_completion, batch_first=True, padding_value=pad_id
        )
        if return_lengths:
            lengths = torch.tensor([c.size(0) for c in batch_completion], dtype=torch.long, device=device)
            return padded_sequences, batch_interpreter_positions, lengths
        return padded_sequences, batch_interpreter_positions

    def _create_interpreter_mask(
        self,
        completion_ids: torch.Tensor,
        interpreter_positions: list[list[tuple[int, int]]],
    ) -> torch.Tensor:
        """
        Create interpreter mask from positions.

        Args:
            completion_ids: Tensor of shape (batch_size, seq_length)
            interpreter_positions: List[List[Tuple[start_idx, end_idx]]]
                - Indices are relative to completion_ids
                - start_idx: inclusive, end_idx: INCLUSIVE (unlike typical Python slicing)

        Returns:
            interpreter_mask: float tensor of shape (batch_size, seq_length)
                1 = model-generated token, 0 = interpreter token
        """
        batch_size, seq_length = completion_ids.shape
        interpreter_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=completion_ids.device)
        for batch_idx, positions_in_sequence in enumerate(interpreter_positions):
            for start_idx, end_idx in positions_in_sequence:
                # Clip the span to the tensor; spans fully outside it are ignored.
                start_idx = max(0, start_idx)
                end_idx = min(end_idx, seq_length - 1)
                if start_idx <= end_idx:
                    interpreter_mask[batch_idx, start_idx : end_idx + 1] = 0
        return interpreter_mask

    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """Roll out completions for a batch of raw rows and score them.

        Returns prompt ids/mask, completion ids, completion mask (up to and
        including the first EOS, padding excluded), interpreter mask and
        per-sequence advantages. Metrics are recorded per process (no gather).
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        prompts = [example["prompt"] for example in inputs]
        prompts_text = [self._render_prompt(prompt) for prompt in prompts]
        prompt_inputs = self.processing_class(
            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        # Trainer's device placement (the override in this class handles raw lists only).
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Custom multi-turn, tool-using generation.
        completion_ids, interpreter_positions, raw_lengths = self._retool_generate_with_interpreter(
            prompt_ids,
            prompt_mask,
            eos_id=self.eos_id,
            interpreter_id=self.interpreter_id,
            code_id=self.code_id,
            max_turns=self.max_turns,
            return_lengths=True,
        )
        completion_ids = completion_ids.to(device)
        raw_lengths = raw_lengths.to(device)

        # Keep tokens up to and including the first EOS. Padding positions are
        # excluded explicitly so a pad id equal to EOS is not mistaken for a
        # terminated sequence.
        batch_size, seq_len = completion_ids.shape
        positions = torch.arange(seq_len, device=device).expand(batch_size, -1)
        not_padding = positions < raw_lengths.unsqueeze(1)
        is_eos = (completion_ids == self.eos_id) & not_padding
        terminated_with_eos = is_eos.any(dim=1)
        eos_idx = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
        if terminated_with_eos.any():
            eos_idx[terminated_with_eos] = is_eos.int().argmax(dim=1)[terminated_with_eos]
        completion_mask = ((positions <= eos_idx.unsqueeze(1)) & not_padding).int()

        interpreter_mask = self._create_interpreter_mask(completion_ids, interpreter_positions)

        # Optionally drop completions that never emitted EOS.
        if self.mask_truncated_completions:
            completion_mask = completion_mask * terminated_with_eos.unsqueeze(1).int()

        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        ground_truths = [example.get("answer") for example in inputs]  # Adjust key name as needed
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        rewards_per_func = self._compute_rewards_per_func(prompts, completions_text, ground_truths, device)
        advantages = rewards_per_func.sum(dim=1)  # no group normalization for now

        # Log the metrics (per process, no gather).
        if mode == "train":
            self.state.num_input_tokens_seen += attention_mask.sum().item()
        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

        completion_lengths = completion_mask.sum(1)
        self._metrics[mode]["completions/mean_length"].append(completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_length"].append(completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_length"].append(completion_lengths.float().max().item())

        term_completion_lengths = completion_lengths[terminated_with_eos]
        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
        if len(term_completion_lengths) == 0:
            term_completion_lengths = torch.zeros(1, device=device)
        self._metrics[mode]["completions/mean_terminated_length"].append(
            term_completion_lengths.float().mean().item()
        )

        for func_idx, name in enumerate(self.reward_func_names):
            func_rewards = rewards_per_func[:, func_idx]
            self._metrics[mode][f"rewards/{name}/mean"].append(func_rewards.mean().item())
            self._metrics[mode][f"rewards/{name}/std"].append(func_rewards.std().item())
            self._textual_logs["rewards"][name].extend(func_rewards.tolist())

        self._textual_logs["prompt"].extend(prompts_text)
        self._textual_logs["completion"].extend(completions_text)

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "interpreter_mask": interpreter_mask,
            "advantages": advantages,
        }

    # ------------------------------------------------------------------ #
    # Loss
    # ------------------------------------------------------------------ #
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
        """Log-probabilities of the last `logits_to_keep` tokens of `input_ids` under `model`.

        Inputs are processed in chunks of `batch_size` rows to reduce memory peak.
        """
        batch_size = batch_size or input_ids.size(0)
        all_logps = []
        for i in range(0, input_ids.size(0), batch_size):
            input_ids_batch = input_ids[i : i + batch_size]
            attention_mask_batch = attention_mask[i : i + batch_size]

            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
            logits = model(
                input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1
            ).logits
            logits = logits[:, :-1, :]  # (B, L-1, V); the last logit predicts a token past the input

            # Slice with an explicit start index: `[:, -0:]` would keep the whole
            # sequence when logits_to_keep == 0. This also drops extra logits on
            # transformers<=4.48, which ignore `logits_to_keep`
            # (https://github.com/huggingface/trl/issues/2770).
            target_ids = input_ids_batch[:, input_ids_batch.size(1) - logits_to_keep :]
            logits = logits[:, logits.size(1) - logits_to_keep :]

            # Divide logits by sampling temperature.
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            logits = logits / self.temperature
            all_logps.append(self.selective_log_softmax(logits, target_ids))
        return torch.cat(all_logps, dim=0)

    @staticmethod
    def selective_log_softmax(logits, index):
        """
        A memory-efficient implementation of the common `log_softmax -> gather` operation.

        This function is equivalent to the following naive implementation:
        ```python
        logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        ```

        Args:
            logits (`torch.Tensor`):
                Logits tensor of shape `(..., num_classes)`.
            index (`torch.Tensor`):
                Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.

        Returns:
            `torch.Tensor`:
                Gathered log probabilities with the same shape as `index`.
        """
        if logits.dtype in [torch.float32, torch.float64]:
            selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
            # loop to reduce peak mem consumption
            logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
            per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
        else:
            # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
            per_token_logps = []
            for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption
                row_logps = row_logits.log_softmax(dim=-1)
                row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
                per_token_logps.append(row_per_token_logps)
            per_token_logps = torch.stack(per_token_logps)
        return per_token_logps

    @staticmethod
    def _nanmin(tensor: torch.Tensor) -> torch.Tensor:
        """Minimum over non-NaN entries (NaN if every entry is NaN)."""
        if torch.isnan(tensor).all():
            return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
        return torch.min(tensor[~torch.isnan(tensor)])

    @staticmethod
    def _nanmax(tensor: torch.Tensor) -> torch.Tensor:
        """Maximum over non-NaN entries (NaN if every entry is NaN)."""
        if torch.isnan(tensor).all():
            return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device)
        return torch.max(tensor[~torch.isnan(tensor)])

    def _compute_loss(self, model, inputs):
        """Clipped-surrogate policy loss over model-generated completion tokens.

        Interpreter tokens (``interpreter_mask == 0``) and masked completion
        tokens contribute neither loss nor KL. The loss is averaged over the
        remaining tokens.
        """
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        # Added for ReTool Trainer: tool feedback is attended to but never trained on.
        final_mask = inputs["interpreter_mask"] * completion_mask

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # only the completion tokens need logits
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # KL divergence to the reference model (k3 estimator), only when enabled.
        if self.beta != 0.0:
            with torch.no_grad():
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, input_ids, attention_mask, logits_to_keep
                )
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        advantages = inputs["advantages"]
        # The PPO ratio is taken against the policy that sampled the completions.
        # With a single update per rollout, that policy is the current one, so
        # the ratio is 1 in value but keeps the policy gradient.
        old_per_token_logps = inputs.get("old_per_token_logps")
        if old_per_token_logps is None:
            old_per_token_logps = per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        masked_loss = per_token_loss * final_mask
        total_valid_tokens = final_mask.sum() + 1e-8  # Avoid division by zero
        loss = masked_loss.sum() / total_valid_tokens

        # Log the metrics
        mode = "train" if self.model.training else "eval"
        if self.beta != 0.0:
            mean_kl = (per_token_kl * final_mask).sum() / final_mask.sum()
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())

        # Compute the clipped probability ratios
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = (is_low_clipped * final_mask).sum() / final_mask.sum()
        high_clip = (is_high_clipped * final_mask).sum() / final_mask.sum()
        clip_ratio = (is_region_clipped * final_mask).sum() / final_mask.sum()

        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/low_min"].append(self._nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/high_max"].append(self._nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
        return loss