bird-of-paradise committed
Commit cb481ca · verified · 1 Parent(s): 06be75b

Upload 5 files


Initial ReTool implementation - RL framework for tool-augmented LLM training

.gitattributes CHANGED
@@ -46,3 +46,5 @@ static/videos/shiba.mp4 filter=lfs diff=lfs merge=lfs -text
46
  static/videos/steve.mp4 filter=lfs diff=lfs merge=lfs -text
47
  static/videos/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
48
  static/videos/toby.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ assets/aime_results.png filter=lfs diff=lfs merge=lfs -text
50
+ assets/retool_rollout_process.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,27 +1,227 @@
1
  ---
2
  title: ReTool Implementation
3
- emoji: 🧠
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: static
 
7
  pinned: false
8
  license: mit
9
- short_description: 'PyTorch implementation of ReTool: RL training framework for '
10
  ---
11
 
12
- # Nerfies
13
 
14
- This is the repository that contains source code for the [Nerfies website](https://nerfies.github.io).
15
 
16
- If you find Nerfies useful for your work please cite:
17
  ```
18
- @article{park2021nerfies
19
- author = {Park, Keunhong and Sinha, Utkarsh and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Seitz, Steven M. and Martin-Brualla, Ricardo},
20
- title = {Nerfies: Deformable Neural Radiance Fields},
21
- journal = {ICCV},
22
- year = {2021},
23
  }
24
  ```
25
 
26
- # Website License
27
- <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-sa/4.0/88x31.png" /></a><br />This work is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/">Creative Commons Attribution-ShareAlike 4.0 International License</a>.
1
  ---
2
  title: ReTool Implementation
3
+ emoji: 🔧
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: static
7
+ app_file: README.md
8
  pinned: false
9
  license: mit
10
+ tags:
11
+ - reinforcement-learning
12
+ - tool-use
13
+ - code-interpreter
14
+ - mathematical-reasoning
15
+ - rl-training
16
+ - ppo
17
+ - research-implementation
18
+ language: en
19
+ library_name: transformers
20
  ---
21
 
 
22
 
23
+ # ReTool: Reinforcement Learning for Strategic Tool Use in LLMs
24
 
25
+ A PyTorch implementation of **ReTool** from the paper ["ReTool: Reinforcement Learning for Strategic Tool Use in LLMs"](https://arxiv.org/abs/2504.11536) by Feng et al. (2025).
26
+
27
+ ReTool enhances long-form reasoning by integrating code interpreter execution into the RL training loop, enabling models to learn when and how to invoke computational tools for mathematical problem solving.
28
+
29
+ <div align="center">
30
+ <img src="assets/retool_rollout_process.png" alt="ReTool Rollout Process" width="80%">
31
+ <p><em>Figure 2: Comparison of standard text-based RL vs ReTool's code-integrated training process</em></p>
32
+ </div>
33
+
34
+ ## 🚀 Key Features
35
+
36
+ - **Multi-turn Generation**: Dynamic code execution during reasoning with KV-cache optimization
37
+ - **Strategic Tool Use**: Learns when and how to invoke code interpreters through RL
38
+ - **Interpreter Masking**: Excludes external tool outputs from gradient computation
39
+ - **Transformers-based**: Built on Hugging Face Transformers with proper batching; distributed training support is still being hardened (see Current Status)
40
+
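+ The interpreter-masking idea can be shown with a minimal, self-contained sketch (plain Python over raw text; the real trainer does the same thing at the token level in `_create_interpreter_mask`):
+
+ ```python
+ import re
+
+ # A rollout interleaves model-generated text with tool feedback wrapped in tags.
+ rollout = (
+     "Let x = 2**10. <code>print(2**10)</code>"
+     "<interpreter>1024</interpreter> So the answer is \\boxed{1024}."
+ )
+
+ # Everything inside <interpreter>...</interpreter> is tool output, not model output,
+ # so it is excluded from the RL gradient: 1 = train on it, 0 = skip.
+ mask = [1] * len(rollout)
+ for m in re.finditer(r"<interpreter>.*?</interpreter>", rollout, re.DOTALL):
+     start, end = m.span()
+     mask[start:end] = [0] * (end - start)
+
+ print(sum(mask), "of", len(mask), "characters contribute to the loss")
+ ```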
41
+
42
+ ## 📊 Performance
43
+
44
+ <div align="center">
45
+ <img src="assets/aime_results.png" alt="AIME Results" width="70%">
46
+ <p><em>Figure 1: ReTool achieves 67% accuracy on AIME 2024, significantly outperforming text-based RL (40%)</em></p>
47
+ </div>
48
+
49
+ ## 🛠️ Installation
50
+
51
+ ```bash
52
+ git clone https://github.com/yourusername/retool-implementation.git
53
+ cd retool-implementation/src
54
+ pip install -r requirements.txt
55
+ ```
56
+
57
+ ## 🚧 Current Status
58
+
59
+ **This is a research implementation based on the ReTool paper.** The core components are implemented but not yet fully tested.
60
+
61
+ ### What's Implemented ✅
62
+ - Multi-turn generation with KV-cache optimization
63
+ - Interpreter token masking for RL training
64
+ - Modified PPO loss computation
65
+ - Complete training pipeline structure
66
+ - Proper tensor handling and batching
67
+
68
+ ### What Needs Testing/Integration 🔧
69
+ - End-to-end training verification
70
+ - Code execution sandbox integration
71
+ - Edge case handling for truncated sequences
72
+ - Memory optimization for large models
73
+
74
+ ### For Researchers & Developers
75
+
76
+ This implementation serves as a foundation for:
77
+ - Understanding ReTool's architecture
78
+ - Building upon the multi-turn generation approach
79
+ - Integrating custom code execution environments
80
+ - Extending to other tool-use scenarios
81
+
82
+ ## 📊 Dataset Format
83
+
84
+ Your dataset should contain dictionaries with:
85
+
86
+ ```python
87
+ {
88
+ "prompt": "Solve this math problem: ...",
89
+ "answer": "42" # Ground truth for reward computation
90
+ }
91
+ ```
92
+
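+ A toy dataset in this format can be assembled with 🤗 Datasets (illustrative prompts only; only the `prompt` and `answer` fields are required):
+
+ ```python
+ from datasets import Dataset
+
+ train_dataset = Dataset.from_list([
+     {"prompt": "Solve this math problem: what is 6 * 7?", "answer": "42"},
+     {"prompt": "Solve this math problem: what is 2**10?", "answer": "1024"},
+ ])
+
+ print(train_dataset[0])  # {'prompt': 'Solve this math problem: what is 6 * 7?', 'answer': '42'}
+ ```
+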
93
+ ## 🔍 How It Works
94
+
95
+ 1. **Multi-turn Generation**: Model generates reasoning step-by-step
96
+ 2. **Code Detection**: When `</code>` is generated, extract and execute code
97
+ 3. **Tool Integration**: Append `<interpreter>result</interpreter>` to context
98
+ 4. **Continued Reasoning**: Model continues with tool feedback
99
+ 5. **Reward Computation**: Binary reward based on final answer correctness
100
+ 6. **RL Training**: PPO updates exclude interpreter tokens from loss
101
+
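+ In code, steps 1-4 boil down to a loop like the following (a simplified, text-level sketch: `generate_segment` and `run_sandbox` are hypothetical stand-ins for the model call and the code sandbox, while the real `_retool_generate_with_interpreter` works on token IDs with KV-cache reuse):
+
+ ```python
+ import re
+
+ def retool_rollout(prompt, generate_segment, run_sandbox, max_turns=10):
+     """One multi-turn rollout: generate until </code>, execute, append feedback, repeat."""
+     context = prompt
+     for _ in range(max_turns):
+         segment = generate_segment(context)          # stops at </code> or end-of-sequence
+         context += segment
+         if not segment.rstrip().endswith("</code>"):
+             break                                    # reasoning finished, no further tool call
+         code = re.findall(r"<code>(.*?)</code>", context, re.DOTALL)[-1]
+         result = run_sandbox(code)                   # sandbox execution is still a TODO in this repo
+         context += f"<interpreter>{result}</interpreter>"
+     return context
+ ```
+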
102
+ ## ⚙️ Key Components
103
+
104
+ ### ReToolTrainer Class
105
+
106
+ - `_retool_generate_with_interpreter()`: Multi-turn generation with tool execution
107
+ - `_create_interpreter_mask()`: Creates masks for excluding tool outputs
108
+ - `_compute_loss()`: Modified PPO loss with interpreter masking
109
+ - `_compute_rewards_and_advantages()`: Binary reward computation
110
+
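+ The effect of the interpreter mask on the loss can be seen with toy tensors (names and shapes follow `_compute_loss`; the PPO ratio and clipping terms are omitted for brevity):
+
+ ```python
+ import torch
+
+ per_token_loss   = torch.randn(2, 5)                 # stand-in for the clipped PPO per-token loss
+ completion_mask  = torch.tensor([[1, 1, 1, 1, 0],
+                                  [1, 1, 1, 1, 1]])   # 0 = padding / after EOS
+ interpreter_mask = torch.tensor([[1, 1, 0, 0, 1],
+                                  [1, 1, 1, 1, 1]])   # 0 = tokens produced by the interpreter
+
+ final_mask = interpreter_mask * completion_mask      # keep only valid, model-generated tokens
+ loss = (per_token_loss * final_mask).sum() / (final_mask.sum() + 1e-8)
+ print(loss)
+ ```
+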
111
+ ### Configuration Options
112
+
113
+ ```python
114
+ trainer = ReToolTrainer(
115
+ # ... model and data ...
116
+ max_turns=10, # Maximum reasoning turns
117
+ temperature=0.7, # Generation temperature
118
+ max_completion_length=1024, # Max tokens per turn
119
+ mask_truncated_completions=True, # Handle incomplete sequences
120
+ )
121
+ ```
122
+
123
+ ## 💡 Usage Example (Conceptual)
124
+
125
+ ```python
126
+ from retool_trainer import ReToolTrainer
127
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
128
+
129
+ # This shows the intended API - full testing in progress
130
+ trainer = ReToolTrainer(
131
+ model=AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B-Instruct"),
132
+ processing_class=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct"),
133
+ args=TrainingArguments(...),
134
+ train_dataset=your_math_dataset,
135
+ max_turns=10,
136
+ )
137
+
138
+ # trainer.train() # Full integration testing in progress
139
  ```
140
+
141
+ ## 📈 Results From Paper
142
+
143
+ - **AIME 2024**: 67% accuracy (vs 40% text-based RL)
144
+ - **AIME 2025**: 49.3% accuracy (vs 36.7% text-based RL)
145
+ - **Efficiency**: Converges in 400 steps vs 1080 for baseline
146
+ - **Token Efficiency**: 40% reduction in response length
147
+
148
+ ## 🚧 Limitations & TODOs
149
+
150
+ - [ ] Code execution sandbox integration
151
+ - [ ] Support for multiple reward functions
152
+ - [ ] Advanced error handling for malformed code
153
+ - [ ] Distributed training optimizations
154
+ - [ ] Tool selection beyond code interpreter
155
+
156
+
157
+
158
+
159
+
160
+ ## 📚 Citation
161
+
162
+ ```bibtex
163
+ @article{feng2025retool,
164
+ title={ReTool: Reinforcement Learning for Strategic Tool Use in LLMs},
165
+ author={Feng, Jiazhan and Huang, Shijue and Qu, Xingwei and Zhang, Ge and Qin, Yujia and Zhong, Baoquan and Jiang, Chengquan and Chi, Jinxin and Zhong, Wanjun},
166
+ journal={arXiv preprint arXiv:2504.11536},
167
+ year={2025}
168
  }
169
  ```
170
 
171
+ ## 📄 License
172
+
173
+ MIT License - see [LICENSE](LICENSE) file for details.
174
+
175
+ ## 🤝 Collaboration Welcome (But Not Required)
176
+
177
+ I'm perfectly happy working on this solo, but collaboration can be rewarding when there's mutual value and a good fit.
178
+
179
+ ### 🛠️ Areas Where I'd Value Expertise
180
+
181
+ **Distributed Sandbox Engineering:**
182
+ - Asynchronous code execution environment with load balancing
183
+ - Worker pool architecture for parallel code execution
184
+ - Systems engineering and containerization expertise
185
+
186
+ **Dataset Engineering:**
187
+ - Mathematical reasoning dataset curation and validation
188
+ - Cold-start data pipeline design
189
+ - Quality control and formatting workflows
190
+
191
+ ### 🚀 Collaboration Approach
192
+
193
+ - **Start small:** Open an issue to discuss your approach first
194
+ - **Show, don't tell:** Small proof-of-concept before larger contributions
195
+ - **Quality focused:** Code review and documentation required
196
+ - **Clear attribution:** All substantial contributors get proper credit
197
+
198
+ ### 💰 The Compute Reality
199
+
200
+ **Full training requires significant resources:**
201
+ - ~8x A100s for complete AIME validation
202
+ - Currently exploring compute sponsorship options
203
+ - Happy to validate on smaller models first
204
+
205
+ ### 🎯 What I'm Looking For
206
+
207
+ - People who bring complementary skills (not just ML knowledge)
208
+ - Contributors who can work independently and deliver quality
209
+ - Collaborative mindset without drama or politics
210
+
211
+ **Interested?** Open an issue with your background and what you'd like to work on. Let's see if there's a good fit!
212
+
213
+ *No pressure though - I genuinely enjoy the solo research implementation process too.* 😊
214
+
215
+ ## 🙏 Acknowledgments
216
+
217
+ - Original paper authors for the ReTool framework
218
+ - HuggingFace team for the transformers library
219
+ - TRL team for GRPO implementation patterns
220
+
221
+ ---
222
+
223
+ <div align="center">
224
+ <strong>Built with ❤️ for advancing AI reasoning capabilities</strong>
225
+ </div>
226
+
227
+
assets/aime_results.png ADDED

Git LFS Details

  • SHA256: d78afcdfde8fc82d9767ca9e17e975fb6389fae153ab661be0f3c0ec92e8f45d
  • Pointer size: 131 Bytes
  • Size of remote file: 210 kB
assets/retool_rollout_process.png ADDED

Git LFS Details

  • SHA256: 50a0c13e6850c0c433e2200024c629a3fef587d5c095c48c59daa38283fac6e7
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB
src/requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ # Core ML libraries
2
+ torch>=2.0.0
3
+ transformers>=4.36.0
4
+ datasets>=2.14.0
5
+ accelerate>=0.24.0
6
+
7
+ # TRL utilities (chat templating, profiling helpers, GRPO-style training patterns)
8
+ trl>=0.7.0
9
+
10
+ # Utilities
11
+ packaging>=21.0
12
+ numpy>=1.21.0
13
+
14
+ # Profiling
15
+ profiling-decorator
16
+
17
+ # Optional but commonly used
18
+ wandb>=0.15.0 # for experiment tracking
19
+ tensorboard>=2.10.0 # alternative logging
src/retool_trainer.py ADDED
@@ -0,0 +1,579 @@
1
+ from typing import Any, Callable, Optional, Union
2
+ from collections import defaultdict
3
+ import re
4
+ from trl.data_utils import maybe_apply_chat_template
+ from trl.extras.profiling import profiling_decorator
5
+
6
+ import datasets
7
+ import torch
8
+ import torch.nn.functional as F
+ import torch.utils.data
9
+ import transformers
10
+ from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
11
+ from datasets import Dataset, IterableDataset
12
+ from packaging import version
13
+ from torch import nn
14
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
15
+ from torch.utils.data import DataLoader, Sampler
16
+ from transformers import (
17
+ AutoModelForCausalLM,
18
+ AutoModelForSequenceClassification,
19
+ AutoTokenizer,
20
+ GenerationConfig,
21
+ PreTrainedModel,
22
+ PreTrainedTokenizerBase,
23
+ Trainer,
24
+ TrainerCallback,
25
+ is_wandb_available,
26
+ PreTrainedTokenizer,
27
+ )
28
+
29
+
30
+
31
+ class ReToolTrainer(Trainer):
32
+
33
+ def __init__(
34
+ self,
35
+ model: Optional[PreTrainedModel] = None,
36
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
37
+ args: Optional[transformers.TrainingArguments] = None,
38
+ reward_funcs: Optional[list[Callable]] = None,
39
+ train_dataset: Optional[Dataset] = None,
40
+ eval_dataset: Optional[Dataset] = None,
41
+ # ReTool specific parameters - same as before
42
+ eos_id: Optional[int] = None,
43
+ interpreter_id: Optional[list[int]] = None,
44
+ code_id: Optional[list[int]] = None,
45
+ max_turns: int = 10,
46
+ max_completion_length: int = 1024,
+ max_prompt_length: Optional[int] = None,
47
+ temperature: float = 0.7,
48
+ top_p: float = 0.9,
49
+ top_k: int = 50,
50
+ min_p: Optional[float] = None,
51
+ mask_truncated_completions: bool = True,
52
+ **kwargs
53
+ ):
54
+ # Initialize parent Trainer (simpler call)
55
+ super().__init__(
56
+ model=model,
57
+ tokenizer=processing_class, # Note: Trainer uses 'tokenizer', not 'processing_class'
58
+ args=args,
59
+ train_dataset=train_dataset,
60
+ eval_dataset=eval_dataset,
61
+ **kwargs
62
+ )
63
+
64
+
65
+ # Store processing_class for compatibility
66
+ self.processing_class = processing_class or self.tokenizer
67
+
68
+ # Add reward function handling (since Trainer doesn't have this)
69
+ self.reward_funcs = reward_funcs or [self._binary_reward_function]
70
+
71
73
+
74
+
75
+ # ReTool specific attributes
76
+ self.eos_id = eos_id or self.processing_class.eos_token_id
77
+ self.interpreter_id = interpreter_id or self._get_interpreter_token_ids()
78
+ self.code_id = code_id or self._get_code_token_ids()
79
+ self.max_turns = max_turns
80
+ self.max_completion_length = max_completion_length
+ self.max_prompt_length = max_prompt_length
81
+ self.temperature = temperature
82
+ self.top_p = top_p
83
+ self.top_k = top_k
84
+ self.min_p = min_p
85
+ self.mask_truncated_completions = mask_truncated_completions
86
+
87
+ # ReTool specific logging
88
+ self.reward_func_names = ["binary_correctness"]
89
+ self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
90
+ self._textual_logs = {
91
+ "prompt": [],
92
+ "completion": [],
93
+ "rewards": {"binary_correctness": []}
94
+ }
95
+
96
+ # Generation configuration for ReTool
97
+ self.generation_config = GenerationConfig(
98
+ max_new_tokens=self.max_completion_length, # per-turn budget, not the total across turns
99
+ do_sample=True,
100
+ pad_token_id=self.processing_class.pad_token_id,
101
+ bos_token_id=self.processing_class.bos_token_id,
102
+ eos_token_id=[self.eos_id, self.code_id[1]], # Stop on EOS or </code>
103
+ temperature=self.temperature,
104
+ top_p=self.top_p,
105
+ top_k=self.top_k,
106
+ min_p=self.min_p,
107
+ return_dict_in_generate=True,
108
+ use_cache=True,
109
+ )
110
+
111
+
112
+ def _get_interpreter_token_ids(self) -> list[int]:
113
+ """Get token IDs for <interpreter> and </interpreter> tags."""
114
+ start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
115
+ end_token = self.processing_class.encode("</interpreter>", add_special_tokens=False)[0]
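+ # NOTE: taking [0] assumes each tag encodes to a single token in this tokenizer;
+ # multi-token tags would need to be tracked as full ID sequences instead.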
116
+ return [start_token, end_token]
117
+
118
+ def _get_code_token_ids(self) -> list[int]:
119
+ """Get token IDs for <code> and </code> tags."""
120
+ start_token = self.processing_class.encode("<code>", add_special_tokens=False)[0]
121
+ end_token = self.processing_class.encode("</code>", add_special_tokens=False)[0]
122
+ return [start_token, end_token]
123
+
124
+ def _binary_reward_function(self, prompts, completions, **kwargs) -> list[float]:
125
+ """Default binary reward function for mathematical correctness."""
126
+ rewards = []
127
+ ground_truths = kwargs.get('ground_truths', [None] * len(completions))
128
+
129
+ for completion, ground_truth in zip(completions, ground_truths):
130
+ if self._is_correct_answer(completion, ground_truth):
131
+ rewards.append(1.0)
132
+ else:
133
+ rewards.append(-1.0)
134
+ return rewards
135
+
136
+ def _execute_code(self, code_block: str) -> str:
137
+ """
138
+ Execute code in a sandbox environment.
139
+
140
+ TODO: Implement actual code execution sandbox.
141
+ For now, returns a placeholder.
142
+ """
143
+ # Placeholder implementation
144
+ return f"Executed: {code_block[:50]}... -> Result: 42"
145
+
146
+
147
+ def _check_equivalence(self, predicted, ground_truth):
148
+ """Simple equivalence check - you can make this more sophisticated later."""
149
+ # Simple string comparison for now
150
+ return str(predicted).strip() == str(ground_truth).strip()
151
+
152
+ def _is_correct_answer(self, completion_text, ground_truth):
153
+ import re
154
+ # Look for boxed answer
155
+ match = re.search(r'\\boxed\{([^}]+)\}', completion_text)
156
+ if match:
157
+ predicted = match.group(1)
158
+ return self._check_equivalence(predicted, ground_truth)
159
+ return False
160
+
161
+ def _compute_rewards_and_advantages(self, completions_text, ground_truths, device):
162
+ """Simplified reward and advantage computation for ReTool."""
163
+
164
+ # Compute binary rewards
165
+ rewards = []
166
+ for completion_text, ground_truth in zip(completions_text, ground_truths):
167
+ if self._is_correct_answer(completion_text, ground_truth):
168
+ rewards.append(1.0)
169
+ else:
170
+ rewards.append(-1.0)
171
+
172
+ # For now: advantages = rewards (skip group normalization)
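+ # (A GRPO-style variant would normalize rewards within each prompt group: (r - mean) / std.)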
173
+ advantages = torch.tensor(rewards, dtype=torch.float32, device=device)
174
+
175
+ return advantages
176
+
177
+
178
+ def _retool_generate_with_interpreter(
179
+ self,
180
+ prompt_ids_batch: torch.Tensor, # Full batch of prompts
181
+ attention_mask_batch: torch.Tensor, # Full batch of attention masks for prompts
182
+ # tokenizer is not passed in; self.processing_class is used instead
183
+ eos_id: int, # True end-of-sequence token ID
184
+ interpreter_id: list[int], # [start_id, end_id]
185
+ code_id: list[int], # [start_id, end_id]
186
+ max_turns: int = 10
187
+ ) -> tuple[torch.LongTensor, list[list[tuple[int, int]]]]:
188
+
189
+ batch_size = prompt_ids_batch.size(0)
190
+ batch_completion = []
191
+ batch_interpreter_positions = []
192
+
193
+ for i in range(batch_size): # Process each item in the batch
194
+ # --- Initialization for the current sequence ---
195
+ current_input_id = prompt_ids_batch[i:i+1] # Initial input is the prompt
196
+ current_attention_mask = attention_mask_batch[i:i+1]
197
+ current_kv = None
198
+
199
+ # NEW: Track only the completion part (no prompt)
200
+ cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
201
+ interpreter_positions = []
202
+
203
+ for turn_idx in range(max_turns):
204
+ # --- Stage 1: LM generates text ---
205
+ model_outputs = self.model.generate(
206
+ input_ids=current_input_id,
207
+ attention_mask=current_attention_mask, # This mask is for (history in KV cache + current_input_id)
208
+ eos_token_id=[eos_id, code_id[1]], # code_id[1] is assumed to be </code>'s last token ID
209
+ past_key_values=current_kv,
210
+ generation_config=self.generation_config, # Ensure this has return_dict_in_generate=True, use_cache=True
211
+ # max_new_tokens should be set in self.generation_config appropriately for a segment
212
+ )
213
+
214
+ # Update current_full_ids to the new complete sequence
215
+ current_full_ids = model_outputs.sequences
216
+
217
+ # Newly generated tokens by the LM in THIS step
218
+ completion_id = current_full_ids[:, current_input_id.size(1):]
219
+
220
+ # Add to completion tracking (excludes prompt)
221
+ cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1)
222
+
223
+ # Update current_input_id for the next generation step
224
+ # Update current_attention_mask: it was for (history + current_input_id),
225
+ # now append 1s for completion_id
226
+ current_attention_mask = torch.cat([
227
+ current_attention_mask,
228
+ torch.ones_like(completion_id)
229
+ ], dim=1)
230
+
231
+ current_kv = model_outputs.past_key_values # Cache for the new current_full_ids
232
+
233
+ last_token_id = current_full_ids[0, -1].item()
234
+
235
+ if last_token_id == eos_id or turn_idx == max_turns - 1:
236
+ batch_completion.append(cumulative_completion_ids.squeeze(0))
237
+ batch_interpreter_positions.append(interpreter_positions) # Note: was batch_interpreter_positions[i] = ...
238
+ break
239
+
240
+ if last_token_id == code_id[1]: # Assuming code_id[1] is the specific ID for </code> last token
241
+ # --- Stage 2: Tool Execution ---
242
+ # Extract code from the generated sequence
243
+ full_text = self.processing_class.decode(current_full_ids[0])
244
+ code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)
245
+ if code_match:
246
+ code_block = code_match.group(1)
247
+ interpreter_text = self._execute_code(code_block) # 👈 To do: code sandbox execution 👈
248
+ else:
249
+ interpreter_text = "Error: No code found"
250
+
251
+ formatted_feedback_text = f"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}"
252
+
253
+ interpreter_feedback_id = self.processing_class(
254
+ formatted_feedback_text,
255
+ return_tensors="pt",
256
+ add_special_tokens=False
257
+ ).input_ids.to(current_full_ids.device)
258
+
259
+
260
+ # Record positions relative to cumulative_completion_ids *before* appending feedback
261
+ interpreter_start_idx = cumulative_completion_ids.size(1)
262
+ cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1) # Use cumulative, not current
263
+ interpreter_end_idx = cumulative_completion_ids.size(1) - 1
264
+ interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))
265
+
266
+ # Update attention mask for the appended tool feedback
267
+ current_attention_mask = torch.cat([
268
+ current_attention_mask,
269
+ torch.ones_like(interpreter_feedback_id)
270
+ ], dim=1)
271
+
272
+ # Prepare for the next LM generation step:
273
+ # The model needs to "process" the tool_output_tokens to update its KV cache.
274
+ # The `current_input_id` for the next generate call will be `interpreter_feedback_id`.
275
+ # `current_kv` already holds the cache for `current_full_ids` *before* the tool feedback was appended.
276
+ # The `current_attention_mask` now correctly covers `current_full_ids` (which includes tool feedback).
277
+ current_input_id = interpreter_feedback_id
278
+ # `current_kv` is correct (it's for the prefix before `interpreter_feedback_id`).
279
+ # The next `model.generate` call will use this `current_input_id`, `current_attention_mask`, and `current_kv`.
280
+ else:
281
+ # LM stopped for a reason other than EOS or </code> (e.g., max_new_tokens for the segment)
282
+ batch_completion.append(cumulative_completion_ids.squeeze(0))
283
+ batch_interpreter_positions.append(interpreter_positions)
284
+ # (only completion tokens are returned; the prompt is kept separately by the caller)
285
+ break
286
+ else: # Executed if the loop finished due to max_turns without a break
287
+ batch_completion.append(cumulative_completion_ids.squeeze(0))
288
+ batch_interpreter_positions.append(interpreter_positions)
289
+
290
+
291
+ # Pad sequences in the batch to the same length for returning a single tensor
292
+ # This is a common step if you started with a batch loop.
293
+ # Alternatively, this function could return a list of tensors if lengths vary.
294
+ # For now, assuming you'll handle batch padding outside or return a list.
295
+ # The return type `torch.LongTensor` implies a padded batch.
296
+ padded_sequences = torch.nn.utils.rnn.pad_sequence(batch_completion, batch_first=True, padding_value=self.processing_class.pad_token_id)
297
+
298
+ return padded_sequences, batch_interpreter_positions
299
+
300
+
301
+
302
+ def _create_interpreter_mask(
303
+ self,
304
+ completion_ids: torch.Tensor,
305
+ interpreter_positions: list[list[tuple[int, int]]]
306
+ ) -> torch.Tensor:
307
+ """
308
+ Create interpreter mask from positions.
309
+
310
+ Args:
311
+ completion_ids: Tensor of shape (batch_size, seq_length)
312
+ interpreter_positions: List[List[Tuple[start_idx, end_idx]]]
313
+ - Indices are relative to completion_ids
314
+ - start_idx: inclusive, end_idx: INCLUSIVE (unlike typical Python slicing)
315
+
316
+ Returns:
317
+ interpreter_mask: Tensor of shape (batch_size, seq_length)
318
+ 1 = model-generated token, 0 = interpreter token
319
+ """
320
+ batch_size, seq_length = completion_ids.shape
321
+
322
+ # Initialize mask with all 1s (assume all tokens are model-generated)
323
+ interpreter_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=completion_ids.device)
324
+
325
+ # For each sequence in the batch
326
+ for batch_idx, positions_in_sequence in enumerate(interpreter_positions):
327
+ # For each interpreter section in this sequence
328
+ for start_idx, end_idx in positions_in_sequence:
329
+ # Clamp indices to valid range
330
+ start_idx = max(0, min(start_idx, seq_length - 1))
331
+ end_idx = max(0, min(end_idx, seq_length - 1))
332
+
333
+ # Zero out interpreter tokens (BOTH start and end inclusive)
334
+ if start_idx <= end_idx: # Changed from < to <=
335
+ interpreter_mask[batch_idx, start_idx:end_idx + 1] = 0 # Changed to end_idx + 1
336
+
337
+ return interpreter_mask
338
+
339
+
340
+ def _generate_and_score_completions(
341
+ self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
342
+ ) -> dict[str, Union[torch.Tensor, Any]]:
343
+
344
+ device = self.accelerator.device
345
+ mode = "train" if self.model.training else "eval"
346
+
347
+ prompts = [x["prompt"] for x in inputs]
348
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
349
+ prompt_inputs = self.processing_class(
350
+ text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
351
+ )
352
+ prompt_inputs = super()._prepare_inputs(prompt_inputs)
353
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
354
+
355
+ if self.max_prompt_length is not None:
356
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
357
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
358
+
359
+
360
+ # use custom multi-turn-w-tool-use Generate completions
361
+ completion_ids, interpreter_positions = self._retool_generate_with_interpreter(
362
+ prompt_ids, attention_mask_batch=prompt_mask,
363
+ eos_id=self.eos_id, interpreter_id=self.interpreter_id, code_id=self.code_id, max_turns=self.max_turns,
364
+ )
365
+
366
+
367
+ # Mask everything after the first EOS token
368
+ is_eos = completion_ids == self.processing_class.eos_token_id
369
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
370
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
371
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
372
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
373
+
374
+
375
+ # compute interpreter mask
376
+ interpreter_mask = self._create_interpreter_mask(completion_ids, interpreter_positions)
377
+
378
+
379
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
380
+ if self.mask_truncated_completions:
381
+ truncated_completions = ~is_eos.any(dim=1)
382
+ completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
383
+
384
+ # Concatenate prompt_mask with completion_mask for logit computation
385
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
386
+
387
+
388
+ # no need to return old_per_token_logps
389
+
390
+ # Extract ground truths from inputs
391
+ ground_truths = [x.get("answer") for x in inputs] # Adjust key name as needed
392
+
393
+ # Decode completions for reward computation
394
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
395
+
396
+ # Compute rewards and advantages
397
+ advantages = self._compute_rewards_and_advantages(
398
+ completions_text,
399
+ ground_truths,
400
+ device=device
401
+ )
402
+
403
+
404
+ # Log the metrics
405
+ if mode == "train":
406
+ self.state.num_input_tokens_seen += attention_mask.sum().item() # Skip gather
407
+ self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
408
+
409
+ # Log completion lengths
410
+ completion_lengths = completion_mask.sum(1) # Skip gather
411
+ self._metrics[mode]["completions/mean_length"].append(completion_lengths.float().mean().item())
412
+ self._metrics[mode]["completions/min_length"].append(completion_lengths.float().min().item())
413
+ self._metrics[mode]["completions/max_length"].append(completion_lengths.float().max().item())
414
+
415
+ # Log terminated sequences
416
+ terminated_with_eos = is_eos.any(dim=1) # Skip gather
417
+ term_completion_lengths = completion_lengths[terminated_with_eos]
418
+ clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
419
+ self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
420
+
421
+ if len(term_completion_lengths) == 0:
422
+ term_completion_lengths = torch.zeros(1, device=device)
423
+
424
+ self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
425
+
426
+ # Log rewards (simplified for single reward function)
427
+ advantages_tensor = advantages
428
+ self._metrics[mode]["rewards/binary_correctness/mean"].append(advantages_tensor.mean().item())
429
+ self._metrics[mode]["rewards/binary_correctness/std"].append(advantages_tensor.std().item())
430
+
431
+
432
+ # Log texts for debugging
433
+ self._textual_logs["prompt"].extend(prompts_text)
434
+ self._textual_logs["completion"].extend(completions_text)
435
+ self._textual_logs["rewards"]["binary_correctness"].extend(advantages.tolist())
436
+
437
+ return {
438
+ "prompt_ids": prompt_ids,
439
+ "prompt_mask": prompt_mask,
440
+ "completion_ids": completion_ids,
441
+ "completion_mask": completion_mask,
442
+ "interpreter_mask": interpreter_mask,
443
+ "advantages": advantages
444
+ }
445
+
446
+
447
+ # Get the per-token log probabilities for the completions for the model and the reference model
448
+ @profiling_decorator
449
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
450
+ batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
451
+ all_logps = []
452
+ for i in range(0, input_ids.size(0), batch_size):
453
+ input_ids_batch = input_ids[i : i + batch_size]
454
+ attention_mask_batch = attention_mask[i : i + batch_size]
455
+
456
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
457
+ logits = model(
458
+ input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1
459
+ ).logits
460
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
461
+ input_ids_batch = input_ids_batch[:, -logits_to_keep:]
462
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
463
+ # See https://github.com/huggingface/trl/issues/2770
464
+ logits = logits[:, -logits_to_keep:]
465
+ # Divide logits by sampling temperature.
466
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
467
+ logits = logits / self.temperature
468
+ logps = self.selective_log_softmax(logits, input_ids_batch) # compute logprobs for the input tokens
469
+ all_logps.append(logps)
470
+ return torch.cat(all_logps, dim=0)
471
+
472
+
473
+ @staticmethod
474
+ def selective_log_softmax(logits, index):
475
+ """
476
+ A memory-efficient implementation of the common `log_softmax -> gather` operation.
477
+
478
+ This function is equivalent to the following naive implementation:
479
+ ```python
480
+ logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
481
+ ```
482
+
483
+ Args:
484
+ logits (`torch.Tensor`):
485
+ Logits tensor of shape `(..., num_classes)`.
486
+ index (`torch.Tensor`):
487
+ Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
488
+
489
+ Returns:
490
+ `torch.Tensor`:
491
+ Gathered log probabilities with the same shape as `index`.
492
+ """
493
+ if logits.dtype in [torch.float32, torch.float64]:
494
+ selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
495
+ # loop to reduce peak mem consumption
496
+ logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
497
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
498
+ else:
499
+ # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
500
+ per_token_logps = []
501
+ for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption
502
+ row_logps = F.log_softmax(row_logits, dim=-1)
503
+ row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
504
+ per_token_logps.append(row_per_token_logps)
505
+ per_token_logps = torch.stack(per_token_logps)
506
+ return per_token_logps
507
+
508
+
509
+ def _compute_loss(self, model, inputs):
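+ # NOTE: this method assumes attributes/helpers that are not yet wired up in __init__:
+ # self.ref_model, self.beta, self.epsilon_low, self.epsilon_high, and the nanmin/nanmax
+ # helpers used for logging below. They follow TRL's GRPO-trainer conventions and need to
+ # be provided before end-to-end training (see "Current Status" in the README).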
510
+ # Compute the per-token log probabilities for the model
511
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
512
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
513
+
514
+ # Added for ReTool Trainer
515
+ interpreter_mask = inputs["interpreter_mask"]
516
+ final_mask = interpreter_mask * completion_mask
517
+
518
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
519
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
520
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
521
+
522
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
523
+
524
+ with torch.no_grad():
525
+ ref_per_token_logps = self._get_per_token_logps(
526
+ self.ref_model, input_ids, attention_mask, logits_to_keep
527
+ )
528
+ # Compute the KL divergence between the model and the reference model
529
+ if self.beta != 0.0:
530
+ per_token_kl = (
531
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
532
+ )
533
+
534
+ # Compute the loss
535
+ advantages = inputs["advantages"]
536
+
537
+ old_per_token_logps = ref_per_token_logps # simplification: reuse the reference logps as the behaviour-policy logps
538
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
539
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
540
+
541
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
542
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
543
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
544
+ if self.beta != 0.0:
545
+ per_token_loss = per_token_loss + self.beta * per_token_kl
546
+
547
+
548
+ # For PPO loss
549
+ masked_loss = per_token_loss * final_mask
550
+ total_valid_tokens = final_mask.sum() + 1e-8 # Avoid division by zero
551
+ loss = masked_loss.sum() / total_valid_tokens
552
+
553
+ """ --- """
554
+
555
+ # Log the metrics
556
+ mode = "train" if self.model.training else "eval"
557
+
558
+ if self.beta != 0.0:
559
+ mean_kl = (per_token_kl * final_mask).sum() / final_mask.sum()
560
+ self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())
561
+
562
+ # Compute the clipped probability ratios
563
+ is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
564
+ is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
565
+ is_region_clipped = is_low_clipped | is_high_clipped
566
+
567
+ low_clip = (is_low_clipped * final_mask).sum() / final_mask.sum()
568
+ high_clip = (is_high_clipped * final_mask).sum() / final_mask.sum()
569
+ clip_ratio = (is_region_clipped * final_mask).sum() / final_mask.sum()
570
+
571
+ gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
572
+ self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
573
+ self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
574
+ gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
575
+ self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
576
+ self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
577
+ gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
578
+ self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
579
+ return loss