import re
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union

import datasets
import profiling_decorator
import torch
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)

# `import profiling_decorator` binds a *module*, and applying a module as a decorator raises
# TypeError as soon as the class body below is executed. Use the same-named callable the module
# presumably exports (the way `trl.extras.profiling.profiling_decorator` is shaped) and fall back
# to a no-op decorator. NOTE(review): confirm what the `profiling_decorator` module provides.
_profiling_decorator = getattr(profiling_decorator, "profiling_decorator", lambda func: func)


def _identity_collator(features: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the batch of raw dataset rows unchanged (prompts are tokenized during rollout)."""
    return features


def _nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """Return the minimum of the non-NaN entries of ``tensor``, or NaN if there are none."""
    values = tensor[~torch.isnan(tensor)]
    if values.numel() == 0:
        return torch.tensor(float("nan"), device=tensor.device)
    return values.min()


def _nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """Return the maximum of the non-NaN entries of ``tensor``, or NaN if there are none."""
    values = tensor[~torch.isnan(tensor)]
    if values.numel() == 0:
        return torch.tensor(float("nan"), device=tensor.device)
    return values.max()


class ReToolTrainer(Trainer):
    """GRPO-style RL trainer for ReTool: tool-integrated, multi-turn rollouts.

    For each prompt the policy generates text until it closes a code block; the code is executed,
    the interpreter output is appended to the sequence, and generation resumes. The policy loss is
    a clipped policy-gradient objective that ignores the interpreter-feedback tokens.

    Each batch is a list of raw dataset rows carrying a ``"prompt"`` (a string, or a list of chat
    messages) and an ``"answer"`` (ground truth forwarded to the reward functions as
    ``ground_truths``).
    """

    CODE_START_TAG = "<code>"
    CODE_END_TAG = "</code>"
    INTERPRETER_START_TAG = "<interpreter>"
    INTERPRETER_END_TAG = "</interpreter>"
    _CODE_BLOCK_PATTERN = re.compile(
        re.escape(CODE_START_TAG) + r"(.*?)" + re.escape(CODE_END_TAG), re.DOTALL
    )

    def __init__(
        self,
        model: Optional[PreTrainedModel] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        args: Optional[transformers.TrainingArguments] = None,
        reward_funcs: Optional[list[Callable]] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        # ReTool specific parameters
        eos_id: Optional[int] = None,
        interpreter_id: Optional[list[int]] = None,
        code_id: Optional[list[int]] = None,
        max_turns: int = 10,
        max_completion_length: Optional[int] = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        min_p: Optional[float] = None,
        mask_truncated_completions: bool = True,
        max_new_tokens_per_turn: int = 50,
        max_prompt_length: Optional[int] = None,
        ref_model: Optional[PreTrainedModel] = None,
        beta: float = 0.0,
        epsilon: float = 0.2,
        epsilon_high: Optional[float] = None,
        **kwargs,
    ):
        """Build the trainer.

        Args:
            model: Policy model to train.
            processing_class: Tokenizer used for prompts, tags and interpreter feedback (required).
            args: Training arguments; ``remove_unused_columns`` is forced to ``False`` because the
                raw ``prompt``/``answer`` columns are consumed by this trainer, not the model.
            reward_funcs: Callables ``f(prompts=..., completions=..., ground_truths=...)`` returning
                one float per completion; rewards are summed. Defaults to binary correctness.
            train_dataset / eval_dataset: Datasets of rows with ``prompt`` and ``answer``.
            eos_id: Token ID that ends a rollout (defaults to the tokenizer's EOS).
            interpreter_id / code_id: ``[start_id, end_id]`` of the interpreter / code tags.
            max_turns: Maximum number of generate calls per rollout.
            max_completion_length: Hard cap on completion tokens (model text + feedback);
                ``None`` disables the cap.
            temperature / top_p / top_k / min_p: Sampling parameters.
            mask_truncated_completions: Exclude completions without EOS from the loss.
            max_new_tokens_per_turn: Token budget of each individual generate call.
            max_prompt_length: If set, keep only the last ``max_prompt_length`` prompt tokens.
            ref_model: Frozen reference policy; required when ``beta != 0``.
            beta: Weight of the per-token KL penalty against ``ref_model``.
            epsilon / epsilon_high: Lower / upper clipping range of the probability ratio
                (``epsilon_high`` defaults to ``epsilon``).

        Raises:
            ValueError: If no tokenizer is given, or ``beta != 0`` without ``ref_model``.
        """
        if processing_class is None:
            raise ValueError("ReToolTrainer requires a tokenizer passed as `processing_class`.")

        if args is None:
            args = transformers.TrainingArguments(output_dir="tmp_trainer", remove_unused_columns=False)
        elif args.remove_unused_columns:
            # Trainer would otherwise drop every column the model's forward() does not accept,
            # including "prompt" and "answer", leaving nothing to generate from or score against.
            warnings.warn("ReToolTrainer sets `remove_unused_columns=False` on the training arguments.")
            args.remove_unused_columns = False
        # Batches are raw rows (text); the default collator cannot tensorize them.
        if kwargs.get("data_collator") is None:
            kwargs["data_collator"] = _identity_collator

        # `processing_class` replaced the deprecated `tokenizer` argument in transformers 4.46.
        if version.parse(transformers.__version__) >= version.parse("4.46.0"):
            kwargs["processing_class"] = processing_class
        else:
            kwargs["tokenizer"] = processing_class
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )
        self.processing_class = processing_class
        # compute_loss already normalizes over tokens; without this Trainer would skip the
        # division by gradient_accumulation_steps for models that accept loss kwargs.
        self.model_accepts_loss_kwargs = False

        # Reward functions (the default one is logged as "binary_correctness").
        if reward_funcs:
            self.reward_funcs = list(reward_funcs)
            self.reward_func_names = [
                getattr(func, "__name__", f"reward_{idx}") for idx, func in enumerate(self.reward_funcs)
            ]
        else:
            self.reward_funcs = [self._binary_reward_function]
            self.reward_func_names = ["binary_correctness"]

        # ReTool specific attributes (`is None` checks: token ID 0 is valid).
        self.eos_id = eos_id if eos_id is not None else self.processing_class.eos_token_id
        self.interpreter_id = interpreter_id if interpreter_id is not None else self._get_interpreter_token_ids()
        self.code_id = code_id if code_id is not None else self._get_code_token_ids()
        self.max_turns = max_turns
        self.max_completion_length = max_completion_length
        self.max_prompt_length = max_prompt_length
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.mask_truncated_completions = mask_truncated_completions

        # Loss hyper-parameters.
        self.beta = beta
        self.epsilon_low = epsilon
        self.epsilon_high = epsilon_high if epsilon_high is not None else epsilon
        self.ref_model = ref_model
        if self.beta != 0.0 and self.ref_model is None:
            raise ValueError("`beta != 0` requires a frozen reference model passed as `ref_model`.")
        if self.ref_model is not None:
            self.ref_model.eval()
            self.ref_model.requires_grad_(False)
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        # ReTool specific logging
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._textual_logs = {
            "prompt": [],
            "completion": [],
            "rewards": {name: [] for name in self.reward_func_names},
        }

        # Per-turn generation configuration; the overall budget is max_completion_length.
        self.generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens_per_turn,
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
            bos_token_id=self.processing_class.bos_token_id,
            eos_token_id=[self.eos_id, self.code_id[1]],  # stop on EOS or on the end of </code>
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            return_dict_in_generate=True,
            use_cache=True,
        )

    # ------------------------------------------------------------------ tag token IDs

    def _tag_token_ids(self, start_tag: str, end_tag: str) -> list[int]:
        """Return ``[first token ID of start_tag, last token ID of end_tag]``.

        Raises:
            ValueError: If either tag encodes to no tokens.
        """
        start_ids = self.processing_class.encode(start_tag, add_special_tokens=False)
        end_ids = self.processing_class.encode(end_tag, add_special_tokens=False)
        if not start_ids or not end_ids:
            raise ValueError(f"Tags {start_tag!r}/{end_tag!r} encode to no tokens with this tokenizer.")
        if len(start_ids) > 1 or len(end_ids) > 1:
            # Multi-token tags make single-token stop/format logic ambiguous; adding the tags as
            # special tokens to the tokenizer avoids this.
            warnings.warn(f"Tags {start_tag!r}/{end_tag!r} are not single tokens; consider adding them as special tokens.")
        return [start_ids[0], end_ids[-1]]

    def _get_interpreter_token_ids(self) -> list[int]:
        """Get token IDs for the <interpreter> and </interpreter> tags."""
        return self._tag_token_ids(self.INTERPRETER_START_TAG, self.INTERPRETER_END_TAG)

    def _get_code_token_ids(self) -> list[int]:
        """Get token IDs for the <code> and </code> tags."""
        return self._tag_token_ids(self.CODE_START_TAG, self.CODE_END_TAG)

    # ------------------------------------------------------------------ rewards

    def _binary_reward_function(self, prompts, completions, **kwargs) -> list[float]:
        """Default reward: +1.0 if the boxed answer matches the ground truth, else -1.0."""
        ground_truths = kwargs.get("ground_truths", [None] * len(completions))
        return [
            1.0 if self._is_correct_answer(completion, ground_truth) else -1.0
            for completion, ground_truth in zip(completions, ground_truths)
        ]

    def _execute_code(self, code_block: str) -> str:
        """
        Execute code in a sandbox environment.

        TODO: Implement actual code execution sandbox. For now, returns a placeholder.
        Model-generated code is untrusted: never run it with exec/eval in this process.
        """
        # Placeholder implementation
        return f"Executed: {code_block[:50]}... -> Result: 42"

    def _check_equivalence(self, predicted, ground_truth):
        """Simple equivalence check - you can make this more sophisticated later."""
        # Simple string comparison for now
        return str(predicted).strip() == str(ground_truth).strip()

    @staticmethod
    def _extract_last_boxed(text: str) -> Optional[str]:
        """Return the content of the last ``\\boxed{...}`` in ``text`` (nested braces allowed).

        Returns ``None`` when there is no complete, non-empty boxed expression.
        """
        marker = "\\boxed{"
        start = text.rfind(marker)
        if start == -1:
            return None
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[content_start:pos] or None
        return None  # unterminated (e.g. truncated completion)

    def _is_correct_answer(self, completion_text, ground_truth):
        """True if the final ``\\boxed{}`` answer is equivalent to ``ground_truth``."""
        if ground_truth is None:
            return False
        predicted = self._extract_last_boxed(completion_text)
        if predicted is None:
            return False
        return self._check_equivalence(predicted, ground_truth)

    def _compute_rewards(self, prompts, completions_text, ground_truths, device) -> torch.Tensor:
        """Evaluate every reward function; returns a float32 tensor (num_completions, num_funcs)."""
        columns = [
            torch.tensor(
                reward_func(prompts=prompts, completions=completions_text, ground_truths=ground_truths),
                dtype=torch.float32,
                device=device,
            )
            for reward_func in self.reward_funcs
        ]
        return torch.stack(columns, dim=1)

    def _compute_rewards_and_advantages(self, completions_text, ground_truths, device, prompts=None):
        """Simplified reward and advantage computation for ReTool.

        Sums all reward functions per completion. For now advantages = rewards (no group
        normalization).
        """
        if prompts is None:
            prompts = [None] * len(completions_text)
        return self._compute_rewards(prompts, completions_text, ground_truths, device).sum(dim=1)

    # ------------------------------------------------------------------ rollouts

    def _rollout_single(
        self,
        prompt_ids: torch.Tensor,
        prompt_mask: torch.Tensor,
        eos_id: int,
        interpreter_id: list[int],
        code_id: list[int],
        max_turns: int,
    ) -> tuple[torch.Tensor, list[tuple[int, int]]]:
        """Run the generate -> execute code -> append feedback loop for one prompt.

        Args:
            prompt_ids: Prompt token IDs of shape ``(1, prompt_len)``.
            prompt_mask: Attention mask matching ``prompt_ids``.
            eos_id / interpreter_id / code_id / max_turns: See ``_retool_generate_with_interpreter``.

        Returns:
            ``(completion_ids, interpreter_positions)``: a 1-D tensor of every token after the
            prompt, and inclusive ``(start, end)`` spans of interpreter feedback within it.
        """
        prompt_len = prompt_ids.size(1)
        sequence_ids = prompt_ids
        attention_mask = prompt_mask
        past_key_values = None
        interpreter_positions: list[tuple[int, int]] = []
        budget = self.max_completion_length

        for turn_idx in range(max_turns):
            max_new_tokens = self.generation_config.max_new_tokens
            if budget is not None:
                remaining = budget - (sequence_ids.size(1) - prompt_len)
                if remaining <= 0:
                    break
                max_new_tokens = remaining if max_new_tokens is None else min(max_new_tokens, remaining)

            # `generate` must receive the FULL sequence (prompt + all appended tokens); with a cache
            # it only runs the forward pass over the tokens the cache does not cover yet.
            outputs = self.model.generate(
                input_ids=sequence_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                generation_config=self.generation_config,
                eos_token_id=[eos_id, code_id[1]],
                max_new_tokens=max_new_tokens,
            )
            new_tokens = outputs.sequences[:, sequence_ids.size(1):]
            sequence_ids = outputs.sequences
            attention_mask = torch.cat([attention_mask, torch.ones_like(new_tokens)], dim=1)
            past_key_values = outputs.past_key_values

            last_token_id = sequence_ids[0, -1].item()
            # Continue only when the model closed a code block and another turn is allowed;
            # EOS, an exhausted per-turn budget, or the final turn all end the rollout.
            if last_token_id == eos_id or last_token_id != code_id[1] or turn_idx == max_turns - 1:
                break

            # --- Tool execution: run the most recent code block of the *completion* (the prompt
            # may itself contain example code blocks).
            completion_text = self.processing_class.decode(sequence_ids[0, prompt_len:])
            code_blocks = self._CODE_BLOCK_PATTERN.findall(completion_text)
            if code_blocks:
                interpreter_text = self._execute_code(code_blocks[-1])  # TODO: sandboxed execution
            else:
                interpreter_text = "Error: No code found"
            formatted_feedback_text = (
                f"{self.processing_class.decode(interpreter_id[0])}"
                f"{interpreter_text}"
                f"{self.processing_class.decode(interpreter_id[1])}"
            )
            feedback_ids = self.processing_class(
                formatted_feedback_text, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(sequence_ids.device)

            start_idx = sequence_ids.size(1) - prompt_len
            sequence_ids = torch.cat([sequence_ids, feedback_ids], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(feedback_ids)], dim=1)
            interpreter_positions.append((start_idx, sequence_ids.size(1) - prompt_len - 1))

        completion_ids = sequence_ids[0, prompt_len:]
        if budget is not None and completion_ids.size(0) > budget:
            # Appended feedback can overshoot the budget; truncate and clip spans to match.
            completion_ids = completion_ids[:budget]
            interpreter_positions = [
                (start, min(end, budget - 1)) for start, end in interpreter_positions if start < budget
            ]
        return completion_ids, interpreter_positions

    def _retool_generate_with_interpreter(
        self,
        prompt_ids_batch: torch.Tensor,
        attention_mask_batch: torch.Tensor,
        eos_id: int,
        interpreter_id: list[int],
        code_id: list[int],
        max_turns: int = 10,
        return_lengths: bool = False,
    ):
        """Generate tool-integrated completions for a batch of prompts, one prompt at a time.

        Args:
            prompt_ids_batch: Prompt token IDs, shape ``(batch, prompt_len)``.
            attention_mask_batch: Attention mask for the prompts.
            eos_id: True end-of-sequence token ID.
            interpreter_id: ``[start_id, end_id]`` decoded to wrap interpreter feedback.
            code_id: ``[start_id, end_id]``; generation pauses for tool execution on ``end_id``.
            max_turns: Maximum generate calls per prompt.
            return_lengths: Also return the unpadded completion lengths.

        Returns:
            ``(completion_ids, interpreter_positions)`` — completions (prompt excluded) right-padded
            to a ``(batch, max_len)`` LongTensor, plus per-sample inclusive feedback spans — and,
            when ``return_lengths`` is true, a third LongTensor of per-sample lengths.
        """
        completions: list[torch.Tensor] = []
        batch_interpreter_positions: list[list[tuple[int, int]]] = []
        for prompt_ids, prompt_mask in zip(prompt_ids_batch, attention_mask_batch):
            completion_ids, positions = self._rollout_single(
                prompt_ids.unsqueeze(0), prompt_mask.unsqueeze(0), eos_id, interpreter_id, code_id, max_turns
            )
            completions.append(completion_ids)
            batch_interpreter_positions.append(positions)

        pad_id = self.processing_class.pad_token_id
        if pad_id is None:
            pad_id = eos_id
        padded_sequences = nn.utils.rnn.pad_sequence(completions, batch_first=True, padding_value=pad_id)
        if return_lengths:
            lengths = torch.tensor(
                [completion.size(0) for completion in completions],
                dtype=torch.long,
                device=prompt_ids_batch.device,
            )
            return padded_sequences, batch_interpreter_positions, lengths
        return padded_sequences, batch_interpreter_positions

    def _create_interpreter_mask(
        self, completion_ids: torch.Tensor, interpreter_positions: list[list[tuple[int, int]]]
    ) -> torch.Tensor:
        """
        Create interpreter mask from positions.

        Args:
            completion_ids: Tensor of shape (batch_size, seq_length)
            interpreter_positions: List[List[Tuple[start_idx, end_idx]]]
                - Indices are relative to completion_ids
                - start_idx: inclusive, end_idx: INCLUSIVE (unlike typical Python slicing)

        Returns:
            interpreter_mask: float tensor of shape (batch_size, seq_length);
                1 = model-generated token, 0 = interpreter token
        """
        batch_size, seq_length = completion_ids.shape
        interpreter_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=completion_ids.device)
        for batch_idx, spans in enumerate(interpreter_positions):
            for start_idx, end_idx in spans:
                # Clip to the tensor; spans lying entirely outside it end up with start > end
                # and are skipped instead of being clamped onto the last token.
                start_idx = max(0, start_idx)
                end_idx = min(end_idx, seq_length - 1)
                if start_idx <= end_idx:
                    interpreter_mask[batch_idx, start_idx : end_idx + 1] = 0
        return interpreter_mask

    # ------------------------------------------------------------------ batch preparation

    def _prompt_to_text(self, prompt: Union[str, list[dict[str, str]]]) -> str:
        """Return ``prompt`` as text, applying the chat template to conversational prompts."""
        if isinstance(prompt, str):
            return prompt
        return self.processing_class.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        """Roll out completions for raw rows, score them, and build the loss masks.

        Returns a dict with ``prompt_ids``, ``prompt_mask``, ``completion_ids``,
        ``completion_mask``, ``interpreter_mask`` and ``advantages``.
        """
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        prompts = [example["prompt"] for example in inputs]
        prompts_text = [self._prompt_to_text(prompt) for prompt in prompts]
        prompt_inputs = self.processing_class(
            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length :]
            prompt_mask = prompt_mask[:, -self.max_prompt_length :]

        # Multi-turn, tool-using generation.
        completion_ids, interpreter_positions, completion_lengths = self._retool_generate_with_interpreter(
            prompt_ids,
            prompt_mask,
            eos_id=self.eos_id,
            interpreter_id=self.interpreter_id,
            code_id=self.code_id,
            max_turns=self.max_turns,
            return_lengths=True,
        )
        completion_ids = completion_ids.to(device)
        completion_lengths = completion_lengths.to(device)

        # Mask padding and everything after the first EOS. EOS is only searched among real tokens
        # so padding is never mistaken for EOS when pad_token_id == eos_id.
        batch_size, seq_length = completion_ids.shape
        sequence_indices = torch.arange(seq_length, device=device).unsqueeze(0)
        is_real_token = sequence_indices < completion_lengths.unsqueeze(1)
        is_eos = (completion_ids == self.eos_id) & is_real_token
        has_eos = is_eos.any(dim=1)
        eos_idx = torch.full((batch_size,), seq_length, dtype=torch.long, device=device)
        eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos]
        completion_mask = ((sequence_indices <= eos_idx.unsqueeze(1)) & is_real_token).int()

        interpreter_mask = self._create_interpreter_mask(completion_ids, interpreter_positions)

        if self.mask_truncated_completions:
            completion_mask = completion_mask * has_eos.unsqueeze(1).int()

        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        # Rewards (advantages = summed rewards; no group normalization yet).
        ground_truths = [example.get("answer") for example in inputs]  # adjust key name as needed
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        rewards_per_func = self._compute_rewards(prompts, completions_text, ground_truths, device)
        advantages = rewards_per_func.sum(dim=1)

        # --- Metrics (local only, no cross-process gather) ---
        if mode == "train":
            self.state.num_input_tokens_seen += attention_mask.sum().item()
        self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

        lengths = completion_mask.sum(1).float()
        self._metrics[mode]["completions/mean_length"].append(lengths.mean().item())
        self._metrics[mode]["completions/min_length"].append(lengths.min().item())
        self._metrics[mode]["completions/max_length"].append(lengths.max().item())

        term_completion_lengths = lengths[has_eos]
        self._metrics[mode]["completions/clipped_ratio"].append(1 - len(term_completion_lengths) / len(lengths))
        if len(term_completion_lengths) == 0:
            term_completion_lengths = torch.zeros(1, device=device)
        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.mean().item())

        for column, name in enumerate(self.reward_func_names):
            func_rewards = rewards_per_func[:, column]
            self._metrics[mode][f"rewards/{name}/mean"].append(func_rewards.mean().item())
            # std of a single sample is undefined (NaN + warning); report 0 instead.
            func_std = func_rewards.std().item() if func_rewards.numel() > 1 else 0.0
            self._metrics[mode][f"rewards/{name}/std"].append(func_std)
            self._textual_logs["rewards"].setdefault(name, []).extend(func_rewards.tolist())

        self._textual_logs["prompt"].extend(prompts_text)
        self._textual_logs["completion"].extend(completions_text)

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "interpreter_mask": interpreter_mask,
            "advantages": advantages,
        }

    def _prepare_inputs(self, inputs):
        """Trainer hook: turn a batch of raw dataset rows into scored rollouts."""
        if isinstance(inputs, list):
            return self._generate_and_score_completions(inputs)
        return super()._prepare_inputs(inputs)

    # ------------------------------------------------------------------ log-probs and loss

    @_profiling_decorator
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
        """Per-token log-probs of the last ``logits_to_keep`` tokens, computed in chunks."""
        batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
        all_logps = []
        for i in range(0, input_ids.size(0), batch_size):
            input_ids_batch = input_ids[i : i + batch_size]
            attention_mask_batch = attention_mask[i : i + batch_size]

            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
            logits = model(
                input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1
            ).logits
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
            input_ids_batch = input_ids_batch[:, -logits_to_keep:]
            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
            # See https://github.com/huggingface/trl/issues/2770
            logits = logits[:, -logits_to_keep:]
            # Divide logits by sampling temperature.
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            logits = logits / self.temperature
            logps = self.selective_log_softmax(logits, input_ids_batch)  # logprobs of the input tokens
            all_logps.append(logps)
        return torch.cat(all_logps, dim=0)

    @staticmethod
    def selective_log_softmax(logits, index):
        """
        A memory-efficient implementation of the common `log_softmax -> gather` operation.

        This function is equivalent to the following naive implementation:
        ```python
        logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        ```

        Args:
            logits (`torch.Tensor`):
                Logits tensor of shape `(..., num_classes)`.
            index (`torch.Tensor`):
                Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.

        Returns:
            `torch.Tensor`:
                Gathered log probabilities with the same shape as `index`.
        """
        if logits.dtype in [torch.float32, torch.float64]:
            selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
            # loop to reduce peak mem consumption
            logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
            per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
        else:
            # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
            per_token_logps = []
            for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption
                row_logps = nn.functional.log_softmax(row_logits, dim=-1)
                row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
                per_token_logps.append(row_per_token_logps)
            per_token_logps = torch.stack(per_token_logps)
        return per_token_logps

    def _compute_loss(self, model, inputs):
        """Clipped policy-gradient loss (+ optional KL), averaged over model-generated tokens."""
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
        # Interpreter feedback is not produced by the policy, so it is excluded from the loss.
        final_mask = inputs["interpreter_mask"] * completion_mask

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # only the completion tokens need logits
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # KL divergence to the reference model (only when it contributes to the loss).
        if self.beta != 0.0:
            with torch.no_grad():
                ref_per_token_logps = self._get_per_token_logps(self.ref_model, input_ids, attention_mask, logits_to_keep)
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        advantages = inputs["advantages"]
        # One policy update per rollout: the behaviour ("old") policy is the current policy with
        # gradients stopped. Using the reference model here would clip away all gradient once
        # the policy drifts more than epsilon from it.
        old_per_token_logps = per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        # Token-level mean; clamp keeps an all-masked batch at loss 0 instead of dividing by 0.
        loss = (per_token_loss * final_mask).sum() / final_mask.sum().clamp(min=1.0)

        # Log the metrics (denominator may be 0 -> NaN, which nanmean skips).
        mode = "train" if self.model.training else "eval"
        num_valid = final_mask.sum()
        if self.beta != 0.0:
            mean_kl = (per_token_kl * final_mask).sum() / num_valid
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())

        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped
        low_clip = (is_low_clipped * final_mask).sum() / num_valid
        high_clip = (is_high_clipped * final_mask).sum() / num_valid
        clip_ratio = (is_region_clipped * final_mask).sum() / num_valid

        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/low_min"].append(_nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/high_max"].append(_nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
        return loss

    # ------------------------------------------------------------------ Trainer hooks

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Trainer hook: ReTool policy loss for a batch produced by ``_prepare_inputs``."""
        if return_outputs:
            raise ValueError("ReToolTrainer does not support returning model outputs.")
        return self._compute_loss(model, inputs)

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
        """Trainer hook for evaluation: roll out, score, and return only the loss."""
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss = self._compute_loss(model, inputs)
        return loss.mean().detach(), None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        """Merge averaged ReTool metrics into Trainer logs, then reset them."""
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(values) / len(values) for key, values in self._metrics[mode].items() if values}
        if mode == "eval":
            metrics = {f"eval_{key}": value for key, value in metrics.items()}
        logs = {**logs, **metrics}
        # Trainer.log gained `start_time` in transformers 4.47.
        if start_time is not None:
            super().log(logs, start_time)
        else:
            super().log(logs)
        self._metrics[mode].clear()