Upload 5 files
Initial ReTool implementation - RL framework for tool-augmented LLM training
- .gitattributes +2 -0
- README.md +214 -14
- assets/aime_results.png +3 -0
- assets/retool_rollout_process.png +3 -0
- src/requirements.txt +19 -0
- src/retool_trainer.py +579 -0
.gitattributes
CHANGED
@@ -46,3 +46,5 @@ static/videos/shiba.mp4 filter=lfs diff=lfs merge=lfs -text
 static/videos/steve.mp4 filter=lfs diff=lfs merge=lfs -text
 static/videos/teaser.mp4 filter=lfs diff=lfs merge=lfs -text
 static/videos/toby.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/aime_results.png filter=lfs diff=lfs merge=lfs -text
+assets/retool_rollout_process.png filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,27 +1,227 @@
-emoji:
-colorFrom:
-colorTo:
-# Nerfies
---
title: ReTool Implementation
emoji: 🔧
colorFrom: blue
colorTo: purple
sdk: static
app_file: README.md
pinned: false
license: mit
tags:
- reinforcement-learning
- tool-use
- code-interpreter
- mathematical-reasoning
- rl-training
- ppo
- research-implementation
language: en
library_name: transformers
---

# ReTool: Reinforcement Learning for Strategic Tool Use in LLMs

A PyTorch implementation of **ReTool** from the paper ["ReTool: Reinforcement Learning for Strategic Tool Use in LLMs"](https://arxiv.org/abs/2504.11536) by Feng et al. (2025).

ReTool enhances long-form reasoning by integrating code interpreter execution into the RL training loop, enabling models to learn when and how to invoke computational tools for mathematical problem solving.

<div align="center">
<img src="assets/retool_rollout_process.png" alt="ReTool Rollout Process" width="80%">
<p><em>Figure 2: Comparison of standard text-based RL vs ReTool's code-integrated training process</em></p>
</div>

## 🚀 Key Features

- **Multi-turn Generation**: Dynamic code execution during reasoning with KV-cache optimization
- **Strategic Tool Use**: Learns when and how to invoke code interpreters through RL
- **Interpreter Masking**: Excludes external tool outputs from gradient computation
- **Production Ready**: Built on HuggingFace Transformers with proper batching and distributed training support

## 📊 Performance

<div align="center">
<img src="assets/aime_results.png" alt="AIME Results" width="70%">
<p><em>Figure 1: ReTool achieves 67% accuracy on AIME 2024, significantly outperforming text-based RL (40%)</em></p>
</div>

## 🛠️ Installation

```bash
git clone https://github.com/yourusername/retool-implementation.git
cd retool-implementation/src
pip install -r requirements.txt
```

## 🚧 Current Status

**This is a research implementation based on the ReTool paper.** The core components are implemented but not yet fully tested.

### What's Implemented ✅
- Multi-turn generation with KV-cache optimization
- Interpreter token masking for RL training
- Modified PPO loss computation
- Complete training pipeline structure
- Proper tensor handling and batching

### What Needs Testing/Integration 🔧
- End-to-end training verification
- Code execution sandbox integration
- Edge case handling for truncated sequences
- Memory optimization for large models

### For Researchers & Developers

This implementation serves as a foundation for:
- Understanding ReTool's architecture
- Building upon the multi-turn generation approach
- Integrating custom code execution environments
- Extending to other tool-use scenarios

## 📊 Dataset Format

Your dataset should contain dictionaries with:

```python
{
    "prompt": "Solve this math problem: ...",
    "answer": "42"  # Ground truth for reward computation
}
```
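
A toy dataset in this format can be built directly with the `datasets` library. The sketch below is illustrative only: the two problems and the variable name `your_math_dataset` are made up here, chosen to match the conceptual usage example further down.

```python
from datasets import Dataset

# Two toy items in the {"prompt", "answer"} format described above.
train_examples = [
    {"prompt": "Solve this math problem: What is 17 * 24?", "answer": "408"},
    {"prompt": "Solve this math problem: What is the GCD of 84 and 126?", "answer": "42"},
]

your_math_dataset = Dataset.from_list(train_examples)
print(your_math_dataset[0])  # {'prompt': 'Solve this math problem: What is 17 * 24?', 'answer': '408'}
```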

## 🔍 How It Works

1. **Multi-turn Generation**: Model generates reasoning step-by-step
2. **Code Detection**: When `</code>` is generated, extract and execute code (see the sketch after this list)
3. **Tool Integration**: Append `<interpreter>result</interpreter>` to context
4. **Continued Reasoning**: Model continues with tool feedback
5. **Reward Computation**: Binary reward based on final answer correctness
6. **RL Training**: PPO updates exclude interpreter tokens from loss
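
A minimal sketch of steps 2-3, assuming a plain `subprocess`-based Python runner. This is not the project's sandbox (sandbox integration is still a TODO, and `_execute_code` in the trainer is a placeholder); the helper name `run_code_block` and the 10-second timeout are illustrative assumptions, and a real setup should add proper isolation.

```python
import re
import subprocess
import sys

def run_code_block(generated_text: str, timeout_s: int = 10) -> str:
    """Extract a <code>...</code> block and run it in a separate Python process."""
    match = re.search(r"<code>(.*?)</code>", generated_text, re.DOTALL)
    if match is None:
        return "<interpreter>Error: No code found</interpreter>"
    try:
        result = subprocess.run(
            [sys.executable, "-c", match.group(1)],
            capture_output=True, text=True, timeout=timeout_s,
        )
        output = result.stdout if result.returncode == 0 else result.stderr
    except subprocess.TimeoutExpired:
        output = "Error: execution timed out"
    # Step 3: the wrapped result is appended to the context before generation resumes.
    return f"<interpreter>{output.strip()}</interpreter>"
```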

## ⚙️ Key Components

### ReToolTrainer Class

- `_retool_generate_with_interpreter()`: Multi-turn generation with tool execution
- `_create_interpreter_mask()`: Creates masks for excluding tool outputs
- `_compute_loss()`: Modified PPO loss with interpreter masking (illustrated below)
- `_compute_rewards_and_advantages()`: Binary reward computation
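
The masking idea behind `_compute_loss` in miniature: the interpreter mask and completion mask are multiplied together, and the per-token PPO loss is averaged only over the tokens that survive. The tensor values below are toy numbers chosen for illustration.

```python
import torch

# Toy per-token PPO losses for one completion of 6 tokens.
per_token_loss = torch.tensor([[0.5, 0.2, 0.9, 0.4, 0.3, 0.1]])

interpreter_mask = torch.tensor([[1., 1., 0., 0., 1., 1.]])  # 0 = interpreter feedback token
completion_mask  = torch.tensor([[1., 1., 1., 1., 1., 0.]])  # 0 = padding after the first EOS

final_mask = interpreter_mask * completion_mask  # keep only model-generated, non-padded tokens
loss = (per_token_loss * final_mask).sum() / final_mask.sum()
print(loss)  # (0.5 + 0.2 + 0.3) / 3 = tensor(0.3333)
```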

### Configuration Options

```python
trainer = ReToolTrainer(
    # ... model and data ...
    max_turns=10,                     # Maximum reasoning turns
    temperature=0.7,                  # Generation temperature
    max_completion_length=1024,       # Max tokens per turn
    mask_truncated_completions=True,  # Handle incomplete sequences
)
```

## 💡 Usage Example (Conceptual)

```python
from retool_trainer import ReToolTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

# This shows the intended API - full testing in progress
trainer = ReToolTrainer(
    model=AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B-Instruct"),
    processing_class=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct"),
    args=TrainingArguments(...),
    train_dataset=your_math_dataset,
    max_turns=10,
)

# trainer.train()  # Full integration testing in progress
```

## 📈 Results From Paper

- **AIME 2024**: 67% accuracy (vs 40% for text-based RL)
- **AIME 2025**: 49.3% accuracy (vs 36.7% for text-based RL)
- **Efficiency**: Converges in 400 steps vs 1080 for the baseline
- **Token Efficiency**: 40% reduction in response length

## 🚧 Limitations & TODOs

- [ ] Code execution sandbox integration
- [ ] Support for multiple reward functions
- [ ] Advanced error handling for malformed code
- [ ] Distributed training optimizations
- [ ] Tool selection beyond code interpreter

## 📚 Citation

```bibtex
@article{feng2025retool,
  title={ReTool: Reinforcement Learning for Strategic Tool Use in LLMs},
  author={Feng, Jiazhan and Huang, Shijue and Qu, Xingwei and Zhang, Ge and Qin, Yujia and Zhong, Baoquan and Jiang, Chengquan and Chi, Jinxin and Zhong, Wanjun},
  journal={arXiv preprint arXiv:2504.11536},
  year={2025}
}
```

## 📄 License

MIT License - see [LICENSE](LICENSE) file for details.

## 🤝 Collaboration Welcome (But Not Required)

I'm perfectly happy working on this solo, but collaboration can be rewarding when there's mutual value and a good fit.

### 🛠️ Areas Where I'd Value Expertise

**Distributed Sandbox Engineering:**
- Asynchronous code execution environment with load balancing
- Worker pool architecture for parallel code execution
- Systems engineering and containerization expertise

**Dataset Engineering:**
- Mathematical reasoning dataset curation and validation
- Cold-start data pipeline design
- Quality control and formatting workflows

### 🚀 Collaboration Approach

- **Start small:** Open an issue to discuss your approach first
- **Show, don't tell:** Small proof-of-concept before larger contributions
- **Quality focused:** Code review and documentation required
- **Clear attribution:** All substantial contributors get proper credit

### 💰 The Compute Reality

**Full training requires significant resources:**
- ~8x A100s for complete AIME validation
- Currently exploring compute sponsorship options
- Happy to validate on smaller models first

### 🎯 What I'm Looking For

- People who bring complementary skills (not just ML knowledge)
- Contributors who can work independently and deliver quality
- Collaborative mindset without drama or politics

**Interested?** Open an issue with your background and what you'd like to work on. Let's see if there's a good fit!

*No pressure though - I genuinely enjoy the solo research implementation process too.* 😊

## 🙏 Acknowledgments

- Original paper authors for the ReTool framework
- HuggingFace team for the transformers library
- TRL team for GRPO implementation patterns

---

<div align="center">
<strong>Built with ❤️ for advancing AI reasoning capabilities</strong>
</div>
assets/aime_results.png
ADDED
Git LFS Details
assets/retool_rollout_process.png
ADDED
Git LFS Details
src/requirements.txt
ADDED
@@ -0,0 +1,19 @@
# Core ML libraries
torch>=2.0.0
transformers>=4.36.0
datasets>=2.14.0
accelerate>=0.24.0

# TRL for base training utilities (optional, fallback to Trainer if not available)
trl>=0.7.0

# Utilities
packaging>=21.0
numpy>=1.21.0

# Profiling
profiling-decorator

# Optional but commonly used
wandb>=0.15.0  # for experiment tracking
tensorboard>=2.10.0  # alternative logging
src/retool_trainer.py
ADDED
@@ -0,0 +1,579 @@
from typing import Any, Callable, Optional, Union
from collections import defaultdict
import re

import datasets
import torch
import torch.nn.functional as F
import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, Sampler
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    is_wandb_available,
)

try:
    # Chat-template helper from TRL (listed in requirements); only needed for conversational prompts.
    from trl.data_utils import maybe_apply_chat_template
except ImportError:
    def maybe_apply_chat_template(example, tokenizer):
        # Fallback: assume prompts are already plain text.
        return {"prompt": example["prompt"]}


def profiling_decorator(func):
    """No-op stand-in for the optional `profiling-decorator` package used for profiling hooks."""
    return func


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """Minimum over the non-NaN entries of a tensor (used for the clip-ratio metrics)."""
    mask = ~torch.isnan(tensor)
    return tensor[mask].min() if mask.any() else torch.tensor(float("nan"), device=tensor.device)


def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    """Maximum over the non-NaN entries of a tensor (used for the clip-ratio metrics)."""
    mask = ~torch.isnan(tensor)
    return tensor[mask].max() if mask.any() else torch.tensor(float("nan"), device=tensor.device)


class ReToolTrainer(Trainer):

    def __init__(
        self,
        model: Optional[PreTrainedModel] = None,
        processing_class: Optional[PreTrainedTokenizerBase] = None,
        args: Optional[transformers.TrainingArguments] = None,
        reward_funcs: Optional[list[Callable]] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        # ReTool-specific parameters
        eos_id: Optional[int] = None,
        interpreter_id: Optional[list[int]] = None,
        code_id: Optional[list[int]] = None,
        max_turns: int = 10,
        max_completion_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        min_p: Optional[float] = None,
        mask_truncated_completions: bool = True,
        max_prompt_length: Optional[int] = None,
        # PPO / KL hyperparameters (defaults follow common PPO practice)
        ref_model: Optional[PreTrainedModel] = None,
        beta: float = 0.0,
        epsilon_low: float = 0.2,
        epsilon_high: float = 0.2,
        **kwargs,
    ):
        # Initialize parent Trainer (simpler call)
        super().__init__(
            model=model,
            tokenizer=processing_class,  # Note: Trainer uses 'tokenizer', not 'processing_class'
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            **kwargs,
        )

        # Store processing_class for compatibility
        self.processing_class = processing_class or self.tokenizer

        # Add reward function handling (since Trainer doesn't have this)
        self.reward_funcs = reward_funcs or [self._binary_reward_function]

        # ReTool-specific attributes
        self.eos_id = eos_id or self.processing_class.eos_token_id
        self.interpreter_id = interpreter_id or self._get_interpreter_token_ids()
        self.code_id = code_id or self._get_code_token_ids()
        self.max_turns = max_turns
        self.max_completion_length = max_completion_length
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.min_p = min_p
        self.mask_truncated_completions = mask_truncated_completions
        self.max_prompt_length = max_prompt_length

        # Optional frozen reference policy and PPO loss hyperparameters
        self.ref_model = ref_model
        self.beta = beta
        self.epsilon_low = epsilon_low
        self.epsilon_high = epsilon_high

        # ReTool-specific logging
        self.reward_func_names = ["binary_correctness"]
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._textual_logs = {
            "prompt": [],
            "completion": [],
            "rewards": {"binary_correctness": []},
        }

        # Generation configuration for ReTool
        self.generation_config = GenerationConfig(
            max_new_tokens=self.max_completion_length,  # Per turn, not total
            do_sample=True,
            pad_token_id=self.processing_class.pad_token_id,
            bos_token_id=self.processing_class.bos_token_id,
            eos_token_id=[self.eos_id, self.code_id[1]],  # Stop on EOS or </code>
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            return_dict_in_generate=True,
            use_cache=True,
        )
    def _get_interpreter_token_ids(self) -> list[int]:
        """Get token IDs for <interpreter> and </interpreter> tags (assumes each tag encodes to a single token)."""
        start_token = self.processing_class.encode("<interpreter>", add_special_tokens=False)[0]
        end_token = self.processing_class.encode("</interpreter>", add_special_tokens=False)[0]
        return [start_token, end_token]

    def _get_code_token_ids(self) -> list[int]:
        """Get token IDs for <code> and </code> tags (assumes each tag encodes to a single token)."""
        start_token = self.processing_class.encode("<code>", add_special_tokens=False)[0]
        end_token = self.processing_class.encode("</code>", add_special_tokens=False)[0]
        return [start_token, end_token]

    def _binary_reward_function(self, prompts, completions, **kwargs) -> list[float]:
        """Default binary reward function for mathematical correctness."""
        rewards = []
        ground_truths = kwargs.get("ground_truths", [None] * len(completions))

        for completion, ground_truth in zip(completions, ground_truths):
            if self._is_correct_answer(completion, ground_truth):
                rewards.append(1.0)
            else:
                rewards.append(-1.0)
        return rewards

    def _execute_code(self, code_block: str) -> str:
        """
        Execute code in a sandbox environment.

        TODO: Implement actual code execution sandbox.
        For now, returns a placeholder.
        """
        # Placeholder implementation
        return f"Executed: {code_block[:50]}... -> Result: 42"

    def _check_equivalence(self, predicted, ground_truth):
        """Simple equivalence check - can be made more sophisticated later."""
        # Simple string comparison for now
        return str(predicted).strip() == str(ground_truth).strip()

    def _is_correct_answer(self, completion_text, ground_truth):
        # Look for a boxed answer, e.g. \boxed{42}
        match = re.search(r'\\boxed\{([^}]+)\}', completion_text)
        if match:
            predicted = match.group(1)
            return self._check_equivalence(predicted, ground_truth)
        return False

    def _compute_rewards_and_advantages(self, completions_text, ground_truths, device):
        """Simplified reward and advantage computation for ReTool."""

        # Compute binary rewards
        rewards = []
        for completion_text, ground_truth in zip(completions_text, ground_truths):
            if self._is_correct_answer(completion_text, ground_truth):
                rewards.append(1.0)
            else:
                rewards.append(-1.0)

        # For now: advantages = rewards (skip group normalization)
        advantages = torch.tensor(rewards, dtype=torch.float32, device=device)

        return advantages
    def _retool_generate_with_interpreter(
        self,
        prompt_ids_batch: torch.Tensor,        # Full batch of prompts
        attention_mask_batch: torch.Tensor,    # Full batch of attention masks for the prompts
        eos_id: int,                           # True end-of-sequence token ID
        interpreter_id: list[int],             # [start_id, end_id]
        code_id: list[int],                    # [start_id, end_id]
        max_turns: int = 10,
    ) -> tuple[torch.LongTensor, list[list[tuple[int, int]]]]:
        # The tokenizer is taken from self.processing_class.

        batch_size = prompt_ids_batch.size(0)
        batch_completion = []
        batch_interpreter_positions = []

        for i in range(batch_size):  # Process each item in the batch
            # --- Initialization for the current sequence ---
            current_input_id = prompt_ids_batch[i:i + 1]  # Initial input is the prompt
            current_attention_mask = attention_mask_batch[i:i + 1]
            current_kv = None

            # Track only the completion part (no prompt)
            cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)
            interpreter_positions = []

            for turn_idx in range(max_turns):
                # --- Stage 1: LM generates text ---
                model_outputs = self.model.generate(
                    input_ids=current_input_id,
                    attention_mask=current_attention_mask,  # Mask for (history in KV cache + current_input_id)
                    eos_token_id=[eos_id, code_id[1]],  # code_id[1] is assumed to be </code>'s last token ID
                    past_key_values=current_kv,
                    generation_config=self.generation_config,  # return_dict_in_generate=True, use_cache=True
                    # max_new_tokens is set in self.generation_config and applies per segment
                )

                # The new complete sequence (prompt/history + newly generated tokens)
                current_full_ids = model_outputs.sequences

                # Newly generated tokens by the LM in THIS step
                completion_id = current_full_ids[:, current_input_id.size(1):]

                # Add to completion tracking (excludes prompt)
                cumulative_completion_ids = torch.cat([cumulative_completion_ids, completion_id], dim=1)

                # Update current_attention_mask: it was for (history + current_input_id),
                # now append 1s for completion_id
                current_attention_mask = torch.cat([
                    current_attention_mask,
                    torch.ones_like(completion_id)
                ], dim=1)

                current_kv = model_outputs.past_key_values  # Cache for the new current_full_ids

                last_token_id = current_full_ids[0, -1].item()

                if last_token_id == eos_id or turn_idx == max_turns - 1:
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break

                if last_token_id == code_id[1]:  # code_id[1] is the specific ID for </code>'s last token
                    # --- Stage 2: Tool Execution ---
                    # Extract code from the generated sequence
                    full_text = self.processing_class.decode(current_full_ids[0])
                    code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)
                    if code_match:
                        code_block = code_match.group(1)
                        interpreter_text = self._execute_code(code_block)  # TODO: real code-sandbox execution
                    else:
                        interpreter_text = "Error: No code found"

                    formatted_feedback_text = (
                        f"{self.processing_class.decode(interpreter_id[0])}"
                        f"{interpreter_text}"
                        f"{self.processing_class.decode(interpreter_id[1])}"
                    )

                    interpreter_feedback_id = self.processing_class(
                        formatted_feedback_text,
                        return_tensors="pt",
                        add_special_tokens=False
                    ).input_ids.to(current_full_ids.device)

                    # Record positions relative to cumulative_completion_ids *before* appending feedback
                    interpreter_start_idx = cumulative_completion_ids.size(1)
                    cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_feedback_id], dim=1)
                    interpreter_end_idx = cumulative_completion_ids.size(1) - 1
                    interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))

                    # Update attention mask for the appended tool feedback
                    current_attention_mask = torch.cat([
                        current_attention_mask,
                        torch.ones_like(interpreter_feedback_id)
                    ], dim=1)

                    # Prepare for the next LM generation step:
                    # the model needs to "process" the tool feedback tokens to update its KV cache,
                    # so the current_input_id for the next generate call is interpreter_feedback_id.
                    # current_kv already holds the cache for the sequence *before* the feedback,
                    # and current_attention_mask now covers the full sequence including the feedback.
                    current_input_id = interpreter_feedback_id
                else:
                    # LM stopped for a reason other than EOS or </code> (e.g., max_new_tokens for the segment)
                    batch_completion.append(cumulative_completion_ids.squeeze(0))
                    batch_interpreter_positions.append(interpreter_positions)
                    break
            else:  # Executed if the loop finished due to max_turns without a break
                batch_completion.append(cumulative_completion_ids.squeeze(0))
                batch_interpreter_positions.append(interpreter_positions)

        # Pad sequences in the batch to the same length so a single tensor can be returned.
        padded_sequences = torch.nn.utils.rnn.pad_sequence(
            batch_completion, batch_first=True, padding_value=self.processing_class.pad_token_id
        )

        return padded_sequences, batch_interpreter_positions
    def _create_interpreter_mask(
        self,
        completion_ids: torch.Tensor,
        interpreter_positions: list[list[tuple[int, int]]]
    ) -> torch.Tensor:
        """
        Create interpreter mask from positions.

        Args:
            completion_ids: Tensor of shape (batch_size, seq_length)
            interpreter_positions: List[List[Tuple[start_idx, end_idx]]]
                - Indices are relative to completion_ids
                - start_idx: inclusive, end_idx: INCLUSIVE (unlike typical Python slicing)

        Returns:
            interpreter_mask: Tensor of shape (batch_size, seq_length)
                1 = model-generated token, 0 = interpreter token
        """
        batch_size, seq_length = completion_ids.shape

        # Initialize mask with all 1s (assume all tokens are model-generated)
        interpreter_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=completion_ids.device)

        # For each sequence in the batch
        for batch_idx, positions_in_sequence in enumerate(interpreter_positions):
            # For each interpreter section in this sequence
            for start_idx, end_idx in positions_in_sequence:
                # Clamp indices to valid range
                start_idx = max(0, min(start_idx, seq_length - 1))
                end_idx = max(0, min(end_idx, seq_length - 1))

                # Zero out interpreter tokens (both start and end inclusive)
                if start_idx <= end_idx:
                    interpreter_mask[batch_idx, start_idx:end_idx + 1] = 0

        return interpreter_mask
    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:

        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"

        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
        prompt_inputs = self.processing_class(
            text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

        if self.max_prompt_length is not None:
            prompt_ids = prompt_ids[:, -self.max_prompt_length:]
            prompt_mask = prompt_mask[:, -self.max_prompt_length:]

        # Generate completions with the custom multi-turn, tool-using loop
        completion_ids, interpreter_positions = self._retool_generate_with_interpreter(
            prompt_ids,
            attention_mask_batch=prompt_mask,
            eos_id=self.eos_id,
            interpreter_id=self.interpreter_id,
            code_id=self.code_id,
            max_turns=self.max_turns,
        )

        # Mask everything after the first EOS token
        is_eos = completion_ids == self.processing_class.eos_token_id
        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Compute interpreter mask (1 = model token, 0 = interpreter feedback token)
        interpreter_mask = self._create_interpreter_mask(completion_ids, interpreter_positions)

        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
        if self.mask_truncated_completions:
            truncated_completions = ~is_eos.any(dim=1)
            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        # Extract ground truths from inputs
        ground_truths = [x.get("answer") for x in inputs]  # Adjust key name as needed

        # Decode completions for reward computation
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)

        # Compute rewards and advantages
        advantages = self._compute_rewards_and_advantages(
            completions_text,
            ground_truths,
            device=device
        )

        # Log the metrics
        if mode == "train":
            self.state.num_input_tokens_seen += attention_mask.sum().item()  # Skip gather
            self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

        # Log completion lengths
        completion_lengths = completion_mask.sum(1)  # Skip gather
        self._metrics[mode]["completions/mean_length"].append(completion_lengths.float().mean().item())
        self._metrics[mode]["completions/min_length"].append(completion_lengths.float().min().item())
        self._metrics[mode]["completions/max_length"].append(completion_lengths.float().max().item())

        # Log terminated sequences
        terminated_with_eos = is_eos.any(dim=1)  # Skip gather
        term_completion_lengths = completion_lengths[terminated_with_eos]
        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)

        if len(term_completion_lengths) == 0:
            term_completion_lengths = torch.zeros(1, device=device)

        self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())

        # Log rewards (simplified for single reward function)
        advantages_tensor = advantages
        self._metrics[mode]["rewards/binary_correctness/mean"].append(advantages_tensor.mean().item())
        self._metrics[mode]["rewards/binary_correctness/std"].append(advantages_tensor.std().item())

        # Log texts for debugging
        self._textual_logs["prompt"].extend(prompts_text)
        self._textual_logs["completion"].extend(completions_text)
        self._textual_logs["rewards"]["binary_correctness"].extend(advantages.tolist())

        return {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "interpreter_mask": interpreter_mask,
            "advantages": advantages,
        }

    # Get the per-token log probabilities for the completions for the model and the reference model
    @profiling_decorator
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
        batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
        all_logps = []
        for i in range(0, input_ids.size(0), batch_size):
            input_ids_batch = input_ids[i : i + batch_size]
            attention_mask_batch = attention_mask[i : i + batch_size]

            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
            logits = model(
                input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1
            ).logits
            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
            input_ids_batch = input_ids_batch[:, -logits_to_keep:]
            # For transformers<=4.48, the logits_to_keep argument isn't supported, so we drop logits ourselves.
            # See https://github.com/huggingface/trl/issues/2770
            logits = logits[:, -logits_to_keep:]
            # Divide logits by sampling temperature.
            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
            logits = logits / self.temperature
            logps = self.selective_log_softmax(logits, input_ids_batch)  # compute logprobs for the input tokens
            all_logps.append(logps)
        return torch.cat(all_logps, dim=0)

    @staticmethod
    def selective_log_softmax(logits, index):
        """
        A memory-efficient implementation of the common `log_softmax -> gather` operation.

        This function is equivalent to the following naive implementation:
        ```python
        logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        ```

        Args:
            logits (`torch.Tensor`):
                Logits tensor of shape `(..., num_classes)`.
            index (`torch.Tensor`):
                Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.

        Returns:
            `torch.Tensor`:
                Gathered log probabilities with the same shape as `index`.
        """
        if logits.dtype in [torch.float32, torch.float64]:
            selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
            # loop to reduce peak mem consumption
            logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
            per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
        else:
            # logsumexp approach is unstable with bfloat16, fall back to a slightly less efficient approach
            per_token_logps = []
            for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption
                row_logps = F.log_softmax(row_logits, dim=-1)
                row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
                per_token_logps.append(row_per_token_logps)
            per_token_logps = torch.stack(per_token_logps)
        return per_token_logps
    def _compute_loss(self, model, inputs):
        # Compute the per-token log probabilities for the model
        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]

        # Added for ReToolTrainer: exclude interpreter tokens from the loss
        interpreter_mask = inputs["interpreter_mask"]
        final_mask = interpreter_mask * completion_mask

        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        with torch.no_grad():
            if self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, input_ids, attention_mask, logits_to_keep
                )
            else:
                # Without a separate reference model, use the current policy's detached log-probs.
                ref_per_token_logps = per_token_logps.detach()
        # Compute the KL divergence between the model and the reference model
        if self.beta != 0.0:
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        # Compute the loss
        advantages = inputs["advantages"]

        old_per_token_logps = ref_per_token_logps
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl

        # PPO loss: average over model-generated completion tokens only
        masked_loss = per_token_loss * final_mask
        total_valid_tokens = final_mask.sum() + 1e-8  # Avoid division by zero
        loss = masked_loss.sum() / total_valid_tokens

        # Log the metrics
        mode = "train" if self.model.training else "eval"

        if self.beta != 0.0:
            mean_kl = (per_token_kl * final_mask).sum() / final_mask.sum()
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())

        # Compute the clipped probability ratios
        is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
        is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
        is_region_clipped = is_low_clipped | is_high_clipped

        low_clip = (is_low_clipped * final_mask).sum() / final_mask.sum()
        high_clip = (is_high_clipped * final_mask).sum() / final_mask.sum()
        clip_ratio = (is_region_clipped * final_mask).sum() / final_mask.sum()

        gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
        self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
        gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
        self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
        self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
        gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
        self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
        return loss