bird-of-paradise committed on
Commit 0690c9f · verified · 1 parent: a0dec77

adding test suite -- first commit

src/test/retool_genertion_with_cache_test.ipynb ADDED
@@ -0,0 +1,1585 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "colab": {
24
+ "base_uri": "https://localhost:8080/"
25
+ },
26
+ "id": "ejiXlq27sck1",
27
+ "outputId": "d2c846e5-97da-4533-d23f-1cb876d67069"
28
+ },
29
+ "outputs": [
30
+ {
31
+ "output_type": "stream",
32
+ "name": "stdout",
33
+ "text": [
34
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.52.4)\n",
35
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.18.0)\n",
36
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.32.4)\n",
37
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2.0.2)\n",
38
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.2)\n",
39
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
40
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n",
41
+ "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
42
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.1)\n",
43
+ "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.3)\n",
44
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n",
45
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (2025.3.2)\n",
46
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (4.14.0)\n",
47
+ "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (1.1.2)\n",
48
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.4.2)\n",
49
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n",
50
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.4.0)\n",
51
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2025.4.26)\n"
52
+ ]
53
+ }
54
+ ],
55
+ "source": [
56
+ "! pip install transformers"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "source": [
62
+ "! pip install profiling-decorator"
63
+ ],
64
+ "metadata": {
65
+ "colab": {
66
+ "base_uri": "https://localhost:8080/"
67
+ },
68
+ "id": "3Sa_Bpi1srA9",
69
+ "outputId": "6ad4ffd6-1058-4097-acb2-21978fe27ca0"
70
+ },
71
+ "execution_count": 2,
72
+ "outputs": [
73
+ {
74
+ "output_type": "stream",
75
+ "name": "stdout",
76
+ "text": [
77
+ "Collecting profiling-decorator\n",
78
+ " Downloading profiling_decorator-0.0.6-py3-none-any.whl.metadata (6.2 kB)\n",
79
+ "Downloading profiling_decorator-0.0.6-py3-none-any.whl (9.2 kB)\n",
80
+ "Installing collected packages: profiling-decorator\n",
81
+ "Successfully installed profiling-decorator-0.0.6\n"
82
+ ]
83
+ }
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "source": [],
89
+ "metadata": {
90
+ "id": "9IRlvyF-J4Mp"
91
+ },
92
+ "execution_count": null,
93
+ "outputs": []
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "source": [],
98
+ "metadata": {
99
+ "id": "eV3CXXy6J47P"
100
+ },
101
+ "execution_count": null,
102
+ "outputs": []
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "source": [
107
+ "def test_updated_retool_implementation():\n",
108
+ " # 1. Setup model, tokenizer, and device\n",
109
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
110
+ " import torch\n",
111
+ " import transformers\n",
112
+ " import re\n",
113
+ " from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache\n",
114
+ "\n",
115
+ " # Use a model that fits in memory\n",
116
+ " model_name = \"gpt2-medium\"\n",
117
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
118
+ "\n",
119
+ " # Ensure padding token is set\n",
120
+ " if tokenizer.pad_token is None:\n",
121
+ " tokenizer.pad_token = tokenizer.eos_token\n",
122
+ "\n",
123
+ " # Check device\n",
124
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
125
+ " print(f\"Using device: {device}\")\n",
126
+ "\n",
127
+ " # Load model to device\n",
128
+ " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n",
129
+ "\n",
130
+ " # 2. Add special tokens\n",
131
+ " special_tokens = {\n",
132
+ " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n",
133
+ " }\n",
134
+ " tokenizer.add_special_tokens(special_tokens)\n",
135
+ " model.resize_token_embeddings(len(tokenizer))\n",
136
+ "\n",
137
+ " # Get token IDs\n",
138
+ " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n",
139
+ " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n",
140
+ " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n",
141
+ " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n",
142
+ "\n",
143
+ " print(f\"EOS token ID: {tokenizer.eos_token_id}\")\n",
144
+ " print(f\"Pad token ID: {tokenizer.pad_token_id}\")\n",
145
+ " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n",
146
+ " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n",
147
+ "\n",
148
+ " # 3. Create a test version of your ReToolTrainer with custom generation\n",
149
+ " class TestReToolTrainer:\n",
150
+ " def __init__(self, model, tokenizer, device):\n",
151
+ " self.model = model\n",
152
+ " self.processing_class = tokenizer\n",
153
+ " self.device = device\n",
154
+ " self.temperature = 0.7\n",
155
+ " self.top_p = 0.9\n",
156
+ " self.top_k = 50\n",
157
+ "\n",
158
+ " # Ensure pad token is set\n",
159
+ " if self.processing_class.pad_token is None:\n",
160
+ " self.processing_class.pad_token = self.processing_class.eos_token\n",
161
+ "\n",
162
+ " def _execute_code(self, code_block):\n",
163
+ " \"\"\"Mock code execution\"\"\"\n",
164
+ " print(f\"\\n==== EXECUTING CODE ====\")\n",
165
+ " print(f\"{code_block}\")\n",
166
+ " print(f\"========================\\n\")\n",
167
+ " return \"0 1 1 2 3\"\n",
168
+ "\n",
169
+ " def _custom_generate(self, input_ids, attention_mask=None, past_key_values=None, max_new_tokens=50, eos_token_ids=None):\n",
170
+ " \"\"\"Custom generation function that avoids KV cache issues\"\"\"\n",
171
+ " if attention_mask is None:\n",
172
+ " attention_mask = torch.ones_like(input_ids)\n",
173
+ "\n",
174
+ " if eos_token_ids is None:\n",
175
+ " eos_token_ids = [self.processing_class.eos_token_id]\n",
176
+ "\n",
177
+ " # Initialize\n",
178
+ " current_ids = input_ids.clone()\n",
179
+ " current_mask = attention_mask.clone()\n",
180
+ " current_kv = past_key_values\n",
181
+ "\n",
182
+ " # Generate tokens in batches for efficiency\n",
183
+ " all_tokens = []\n",
184
+ " batch_size = 10 # Process this many tokens at once\n",
185
+ "\n",
186
+ " for start_idx in range(0, max_new_tokens, batch_size):\n",
187
+ " # How many tokens to generate in this batch\n",
188
+ " batch_tokens = min(batch_size, max_new_tokens - start_idx)\n",
189
+ "\n",
190
+ " # Accumulate new tokens\n",
191
+ " new_tokens = []\n",
192
+ "\n",
193
+ " for _ in range(batch_tokens):\n",
194
+ " # Forward pass with proper cache handling\n",
195
+ " with torch.no_grad():\n",
196
+ " outputs = self.model(\n",
197
+ " input_ids=current_ids if current_kv is None else current_ids[:, -1:],\n",
198
+ " attention_mask=current_mask if current_kv is None else current_mask[:, -1:],\n",
199
+ " past_key_values=DynamicCache.from_legacy_cache(current_kv) if current_kv is not None else None,\n",
200
+ " use_cache=True\n",
201
+ " )\n",
202
+ "\n",
203
+ " # Sample next token\n",
204
+ " next_token_logits = outputs.logits[:, -1, :] / self.temperature\n",
205
+ " filtered_logits = self._filter_logits(next_token_logits)\n",
206
+ " probs = torch.nn.functional.softmax(filtered_logits, dim=-1)\n",
207
+ " next_token = torch.multinomial(probs, num_samples=1)\n",
208
+ "\n",
209
+ " # Add to accumulated tokens\n",
210
+ " token_id = next_token.item()\n",
211
+ " new_tokens.append(token_id)\n",
212
+ "\n",
213
+ " # Update for next iteration\n",
214
+ " current_ids = torch.cat([current_ids, next_token], dim=1)\n",
215
+ " token_mask = torch.ones((1, 1), device=current_mask.device, dtype=current_mask.dtype)\n",
216
+ " current_mask = torch.cat([current_mask, token_mask], dim=1)\n",
217
+ " current_kv = outputs.past_key_values\n",
218
+ "\n",
219
+ " # Check for stop tokens - include both EOS and code_end\n",
220
+ " if token_id in eos_token_ids:\n",
221
+ " break\n",
222
+ "\n",
223
+ " # Add batch tokens to overall result\n",
224
+ " all_tokens.extend(new_tokens)\n",
225
+ "\n",
226
+ " # Check if we hit a stop token\n",
227
+ " if len(new_tokens) < batch_tokens:\n",
228
+ " break\n",
229
+ "\n",
230
+ " # Convert to tensor\n",
231
+ " result = torch.tensor([all_tokens], device=input_ids.device)\n",
232
+ " return result, current_kv\n",
233
+ "\n",
234
+ " def _filter_logits(self, logits):\n",
235
+ " \"\"\"Apply top-k and top-p filtering\"\"\"\n",
236
+ " if self.top_k > 0:\n",
237
+ " top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)\n",
238
+ " logits[0, :] = torch.full_like(logits[0, :], float('-inf'))\n",
239
+ " logits[0, top_k_indices[0]] = top_k_logits[0]\n",
240
+ "\n",
241
+ " if self.top_p < 1.0:\n",
242
+ " sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)\n",
243
+ " cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)\n",
244
+ "\n",
245
+ " # Remove tokens with cumulative probability above threshold\n",
246
+ " sorted_indices_to_remove = cumulative_probs > self.top_p\n",
247
+ " # Shift the indices to the right to keep the first token above threshold\n",
248
+ " sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()\n",
249
+ " sorted_indices_to_remove[:, 0] = 0\n",
250
+ "\n",
251
+ " # Scatter sorted tensors to original indexing\n",
252
+ " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n",
253
+ " logits[indices_to_remove] = float('-inf')\n",
254
+ "\n",
255
+ " return logits\n",
256
+ "\n",
257
+ " def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):\n",
258
+ " \"\"\"Your updated implementation with custom generation\"\"\"\n",
259
+ " batch_size = prompt_ids_batch.size(0)\n",
260
+ " batch_completion = []\n",
261
+ " batch_interpreter_positions = []\n",
262
+ "\n",
263
+ " for i in range(batch_size):\n",
264
+ " print(f\"Processing batch item {i+1}/{batch_size}\")\n",
265
+ "\n",
266
+ " # Initialize\n",
267
+ " current_input_id = prompt_ids_batch[i:i+1]\n",
268
+ " current_attention_mask = attention_mask_batch[i:i+1]\n",
269
+ " current_kv = None\n",
270
+ "\n",
271
+ " # Track the completion part (no prompt)\n",
272
+ " cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)\n",
273
+ " interpreter_positions = []\n",
274
+ "\n",
275
+ " for turn_idx in range(max_turns):\n",
276
+ " # Check if input is empty\n",
277
+ " if current_input_id.size(1) == 0:\n",
278
+ " print(f\"Turn {turn_idx + 1}: Input is empty, breaking loop\")\n",
279
+ " break\n",
280
+ "\n",
281
+ " print(f\"\\n--- Turn {turn_idx + 1} ---\")\n",
282
+ " print(f\"Current input: {self.processing_class.decode(current_input_id[0])}\")\n",
283
+ " print(f\"KV cache present: {current_kv is not None}\")\n",
284
+ "\n",
285
+ " # Generate with custom function\n",
286
+ " newly_generated_tokens, current_kv = self._custom_generate(\n",
287
+ " input_ids=current_input_id,\n",
288
+ " attention_mask=current_attention_mask,\n",
289
+ " past_key_values=current_kv,\n",
290
+ " max_new_tokens=30,\n",
291
+ " eos_token_ids=[eos_id, code_id[1]]\n",
292
+ " )\n",
293
+ "\n",
294
+ " # Display generated text\n",
295
+ " print(f\"Generated: {self.processing_class.decode(newly_generated_tokens[0])}\")\n",
296
+ "\n",
297
+ " # Add to cumulative completion\n",
298
+ " cumulative_completion_ids = torch.cat([cumulative_completion_ids, newly_generated_tokens], dim=1)\n",
299
+ "\n",
300
+ " # Check last token\n",
301
+ " last_token_id = newly_generated_tokens[0, -1].item() if newly_generated_tokens.size(1) > 0 else None\n",
302
+ " print(f\"Last token ID: {last_token_id}\")\n",
303
+ "\n",
304
+ " # Check for end conditions\n",
305
+ " if last_token_id == eos_id:\n",
306
+ " print(\"Found EOS token, ending generation\")\n",
307
+ " break\n",
308
+ "\n",
309
+ " # Check for code end token\n",
310
+ " if last_token_id == code_id[1]:\n",
311
+ " print(\"Found </code> token, executing code\")\n",
312
+ "\n",
313
+ " # Extract code from the full text\n",
314
+ " full_text = self.processing_class.decode(\n",
315
+ " torch.cat([prompt_ids_batch[i], cumulative_completion_ids[0]], dim=0)\n",
316
+ " )\n",
317
+ " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n",
318
+ "\n",
319
+ " if code_match:\n",
320
+ " code_block = code_match.group(1).strip()\n",
321
+ "\n",
322
+ " # Execute code\n",
323
+ " interpreter_text = self._execute_code(code_block)\n",
324
+ "\n",
325
+ " # Format and add interpreter output\n",
326
+ " formatted_feedback = f\"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}\"\n",
327
+ " interpreter_ids = self.processing_class(\n",
328
+ " formatted_feedback,\n",
329
+ " return_tensors=\"pt\",\n",
330
+ " add_special_tokens=False\n",
331
+ " ).input_ids.to(prompt_ids_batch.device)\n",
332
+ "\n",
333
+ " # Record positions\n",
334
+ " interpreter_start_idx = cumulative_completion_ids.size(1)\n",
335
+ " cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_ids], dim=1)\n",
336
+ " interpreter_end_idx = cumulative_completion_ids.size(1) - 1\n",
337
+ " interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))\n",
338
+ "\n",
339
+ " print(f\"Added interpreter output: {formatted_feedback}\")\n",
340
+ "\n",
341
+ " # Set up for next turn\n",
342
+ " current_input_id = interpreter_ids\n",
343
+ " current_attention_mask = torch.ones_like(current_input_id)\n",
344
+ " # Keep current_kv from previous generation\n",
345
+ " else:\n",
346
+ " print(\"No code block found despite </code> token\")\n",
347
+ " break\n",
348
+ " else:\n",
349
+ " # Continue with the newly generated tokens\n",
350
+ " current_input_id = newly_generated_tokens\n",
351
+ " current_attention_mask = torch.ones_like(current_input_id)\n",
352
+ "\n",
353
+ " # Add to batch results\n",
354
+ " batch_completion.append(cumulative_completion_ids.squeeze(0))\n",
355
+ " batch_interpreter_positions.append(interpreter_positions)\n",
356
+ "\n",
357
+ " # Pad sequences\n",
358
+ " if len(batch_completion) > 0:\n",
359
+ " # Ensure padding_value is a valid integer\n",
360
+ " padding_value = self.processing_class.pad_token_id\n",
361
+ " if padding_value is None:\n",
362
+ " padding_value = 0 # Use 0 as a default if pad_token_id is None\n",
363
+ "\n",
364
+ " padded_sequences = torch.nn.utils.rnn.pad_sequence(\n",
365
+ " batch_completion,\n",
366
+ " batch_first=True,\n",
367
+ " padding_value=padding_value\n",
368
+ " )\n",
369
+ " else:\n",
370
+ " padded_sequences = torch.empty((0, 0), dtype=torch.long, device=prompt_ids_batch.device)\n",
371
+ "\n",
372
+ " return padded_sequences, batch_interpreter_positions\n",
373
+ "\n",
374
+ " # 4. Create test instance\n",
375
+ " tester = TestReToolTrainer(model, tokenizer, device)\n",
376
+ "\n",
377
+ " # 5. Create a test prompt with a complete code block\n",
378
+ " prompt = \"\"\"Let me solve this problem with code:\n",
379
+ "\n",
380
+ "<code>\n",
381
+ "def fibonacci(n):\n",
382
+ " a, b = 0, 1\n",
383
+ " result = []\n",
384
+ " for _ in range(n):\n",
385
+ " result.append(a)\n",
386
+ " a, b = b, a + b\n",
387
+ " return result\n",
388
+ "\n",
389
+ "print(fibonacci(5))\n",
390
+ "</code>\"\"\"\n",
391
+ "\n",
392
+ " # 6. Run the test\n",
393
+ " try:\n",
394
+ " print(\"\\n=== Testing Updated ReTool Implementation ===\\n\")\n",
395
+ "\n",
396
+ " # Encode the prompt\n",
397
+ " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
398
+ " attention_mask = torch.ones_like(prompt_ids)\n",
399
+ "\n",
400
+ " # Run the generation\n",
401
+ " completions, positions = tester._retool_generate_with_interpreter(\n",
402
+ " prompt_ids_batch=prompt_ids,\n",
403
+ " attention_mask_batch=attention_mask,\n",
404
+ " eos_id=tokenizer.eos_token_id,\n",
405
+ " interpreter_id=[interpreter_start_id, interpreter_end_id],\n",
406
+ " code_id=[code_start_id, code_end_id],\n",
407
+ " max_turns=3\n",
408
+ " )\n",
409
+ "\n",
410
+ " # Display results\n",
411
+ " print(\"\\n=== Final Results ===\\n\")\n",
412
+ " print(\"Generated completion:\")\n",
413
+ " print(tokenizer.decode(completions[0]))\n",
414
+ "\n",
415
+ " print(\"\\nFull text:\")\n",
416
+ " print(tokenizer.decode(torch.cat([prompt_ids[0], completions[0]])))\n",
417
+ "\n",
418
+ " print(\"\\nInterpreter positions:\", positions)\n",
419
+ "\n",
420
+ " except Exception as e:\n",
421
+ " import traceback\n",
422
+ " print(f\"Error during testing: {e}\")\n",
423
+ " traceback.print_exc()\n",
424
+ "\n",
425
+ "# Run the test\n",
426
+ "test_updated_retool_implementation()"
427
+ ],
428
+ "metadata": {
429
+ "colab": {
430
+ "base_uri": "https://localhost:8080/"
431
+ },
432
+ "id": "4_E6Eo7EHC_8",
433
+ "outputId": "35b195d9-b0ff-4ddf-c216-fba1c83f40e2"
434
+ },
435
+ "execution_count": 9,
436
+ "outputs": [
437
+ {
438
+ "output_type": "stream",
439
+ "name": "stdout",
440
+ "text": [
441
+ "Using device: cpu\n",
442
+ "EOS token ID: 50256\n",
443
+ "Pad token ID: 50256\n",
444
+ "Code tokens: 50257, 50258\n",
445
+ "Interpreter tokens: 50259, 50260\n",
446
+ "\n",
447
+ "=== Testing Updated ReTool Implementation ===\n",
448
+ "\n",
449
+ "Processing batch item 1/1\n",
450
+ "\n",
451
+ "--- Turn 1 ---\n",
452
+ "Current input: Let me solve this problem with code:\n",
453
+ "\n",
454
+ "<code>\n",
455
+ "def fibonacci(n):\n",
456
+ " a, b = 0, 1\n",
457
+ " result = []\n",
458
+ " for _ in range(n):\n",
459
+ " result.append(a)\n",
460
+ " a, b = b, a + b\n",
461
+ " return result\n",
462
+ "\n",
463
+ "print(fibonacci(5))\n",
464
+ "</code>\n",
465
+ "KV cache present: False\n",
466
+ "Generated: \n",
467
+ "\n",
468
+ "def fibonacci(n):\n",
469
+ "\n",
470
+ " a, b = 0, 1\n",
471
+ "\n",
472
+ " result = []\n",
473
+ "\n",
474
+ "Last token ID: 198\n",
475
+ "\n",
476
+ "--- Turn 2 ---\n",
477
+ "Current input: \n",
478
+ "\n",
479
+ "def fibonacci(n):\n",
480
+ "\n",
481
+ " a, b = 0, 1\n",
482
+ "\n",
483
+ " result = []\n",
484
+ "\n",
485
+ "KV cache present: True\n",
486
+ "Generated: \n",
487
+ " a, b = b, a + b\n",
488
+ "\n",
489
+ " return result\n",
490
+ "\n",
491
+ "print(fibon\n",
492
+ "Last token ID: 261\n",
493
+ "\n",
494
+ "--- Turn 3 ---\n",
495
+ "Current input: \n",
496
+ " a, b = b, a + b\n",
497
+ "\n",
498
+ " return result\n",
499
+ "\n",
500
+ "print(fibon\n",
501
+ "KV cache present: True\n",
502
+ "Generated: acci(5))\n",
503
+ "\n",
504
+ "So the first two methods are all the same, the last one is a little different, and the second one is the\n",
505
+ "Last token ID: 262\n",
506
+ "\n",
507
+ "=== Final Results ===\n",
508
+ "\n",
509
+ "Generated completion:\n",
510
+ "\n",
511
+ "\n",
512
+ "def fibonacci(n):\n",
513
+ "\n",
514
+ " a, b = 0, 1\n",
515
+ "\n",
516
+ " result = []\n",
517
+ "\n",
518
+ " a, b = b, a + b\n",
519
+ "\n",
520
+ " return result\n",
521
+ "\n",
522
+ "print(fibonacci(5))\n",
523
+ "\n",
524
+ "So the first two methods are all the same, the last one is a little different, and the second one is the\n",
525
+ "\n",
526
+ "Full text:\n",
527
+ "Let me solve this problem with code:\n",
528
+ "\n",
529
+ "<code>\n",
530
+ "def fibonacci(n):\n",
531
+ " a, b = 0, 1\n",
532
+ " result = []\n",
533
+ " for _ in range(n):\n",
534
+ " result.append(a)\n",
535
+ " a, b = b, a + b\n",
536
+ " return result\n",
537
+ "\n",
538
+ "print(fibonacci(5))\n",
539
+ "</code>\n",
540
+ "\n",
541
+ "def fibonacci(n):\n",
542
+ "\n",
543
+ " a, b = 0, 1\n",
544
+ "\n",
545
+ " result = []\n",
546
+ "\n",
547
+ " a, b = b, a + b\n",
548
+ "\n",
549
+ " return result\n",
550
+ "\n",
551
+ "print(fibonacci(5))\n",
552
+ "\n",
553
+ "So the first two methods are all the same, the last one is a little different, and the second one is the\n",
554
+ "\n",
555
+ "Interpreter positions: [[]]\n"
556
+ ]
557
+ }
558
+ ]
559
+ },
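Reviewer note: the `_filter_logits` helper in the cell above mixes top-k and top-p (nucleus) filtering in place. Below is a minimal standalone sketch of that same step, assuming a `(1, vocab_size)` logits tensor; the function name and shapes are illustrative, not part of the test suite.

```python
import torch

def filter_logits(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
    """Minimal top-k / top-p (nucleus) filter for a (1, vocab_size) logits tensor."""
    if top_k > 0:
        # Keep only the top_k highest logits, mask the rest to -inf
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Mask tokens once cumulative probability exceeds top_p,
        # but always keep the single most likely token
        to_remove = cumulative_probs > top_p
        to_remove[..., 1:] = to_remove[..., :-1].clone()
        to_remove[..., 0] = False
        indices_to_remove = to_remove.scatter(1, sorted_indices, to_remove)
        logits = logits.masked_fill(indices_to_remove, float("-inf"))
    return logits

# Sampling then proceeds as in the cell: temperature, softmax, multinomial
logits = torch.randn(1, 50257)
probs = torch.softmax(filter_logits(logits) / 0.7, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```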
560
+ {
561
+ "cell_type": "code",
562
+ "source": [],
563
+ "metadata": {
564
+ "id": "Z0EHHkP3J7Ox"
565
+ },
566
+ "execution_count": null,
567
+ "outputs": []
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "source": [
572
+ "def test_retool_core_functionality():\n",
573
+ " # 1. Create minimal model and tokenizer\n",
574
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
575
+ " import torch\n",
576
+ " import transformers\n",
577
+ " import re\n",
578
+ "\n",
579
+ " # Use a small model for testing\n",
580
+ " model_name = \"gpt2-medium\"\n",
581
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
582
+ "\n",
583
+ " # Check if CUDA is available\n",
584
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
585
+ " print(f\"Using device: {device}\")\n",
586
+ "\n",
587
+ " # Load model directly to the selected device\n",
588
+ " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n",
589
+ "\n",
590
+ " # 2. Add special tokens to the tokenizer\n",
591
+ " special_tokens = {\n",
592
+ " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n",
593
+ " }\n",
594
+ " tokenizer.add_special_tokens(special_tokens)\n",
595
+ " model.resize_token_embeddings(len(tokenizer))\n",
596
+ "\n",
597
+ " # Get token IDs for special tokens\n",
598
+ " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n",
599
+ " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n",
600
+ " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n",
601
+ " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n",
602
+ "\n",
603
+ " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n",
604
+ " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n",
605
+ "\n",
606
+ " # 3. Create a simplified implementation of _retool_generate_with_interpreter\n",
607
+ " def simplified_generate_with_interpreter(model, tokenizer, prompt_text, device):\n",
608
+ " \"\"\"Simplified version focusing just on the core functionality\"\"\"\n",
609
+ " # Step 1: Tokenize the prompt\n",
610
+ " prompt_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\").to(device)\n",
611
+ "\n",
612
+ " # Initialize tracking variables\n",
613
+ " cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=device)\n",
614
+ " interpreter_positions = []\n",
615
+ "\n",
616
+ " # Step 2: Extract a code block and execute it\n",
617
+ " full_text = prompt_text\n",
618
+ " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n",
619
+ "\n",
620
+ " if code_match:\n",
621
+ " code_block = code_match.group(1).strip()\n",
622
+ " print(f\"Found code block: {code_block}\")\n",
623
+ "\n",
624
+ " # Mock code execution\n",
625
+ " interpreter_output = \"0 1 1 2 3\"\n",
626
+ " print(f\"Code execution result: {interpreter_output}\")\n",
627
+ "\n",
628
+ " # Format interpreter feedback\n",
629
+ " interpreter_text = f\"<interpreter>{interpreter_output}</interpreter>\"\n",
630
+ " interpreter_ids = tokenizer.encode(\n",
631
+ " interpreter_text,\n",
632
+ " return_tensors=\"pt\",\n",
633
+ " add_special_tokens=False\n",
634
+ " ).to(device)\n",
635
+ "\n",
636
+ " # Step 3: Generate a continuation after the interpreter output\n",
637
+ " # First, create a sequence with prompt + interpreter output\n",
638
+ " combined_input = prompt_text + interpreter_text\n",
639
+ " combined_ids = tokenizer.encode(combined_input, return_tensors=\"pt\").to(device)\n",
640
+ "\n",
641
+ " # Generate continuation\n",
642
+ " with torch.no_grad():\n",
643
+ " continuation_outputs = model.generate(\n",
644
+ " input_ids=combined_ids,\n",
645
+ " max_new_tokens=50,\n",
646
+ " do_sample=True,\n",
647
+ " temperature=0.7,\n",
648
+ " top_p=0.9,\n",
649
+ " pad_token_id=tokenizer.pad_token_id,\n",
650
+ " eos_token_id=tokenizer.eos_token_id,\n",
651
+ " return_dict_in_generate=True,\n",
652
+ " cache_implementation= 'offloaded',\n",
653
+ " )\n",
654
+ "\n",
655
+ " # Extract only the newly generated tokens\n",
656
+ " continuation_tokens = continuation_outputs.sequences[:, combined_ids.size(1):]\n",
657
+ "\n",
658
+ " # Combine everything for the final result\n",
659
+ " # The completion consists of: interpreter output + continuation\n",
660
+ " cumulative_completion_ids = torch.cat([interpreter_ids, continuation_tokens], dim=1)\n",
661
+ "\n",
662
+ " # Record the interpreter position\n",
663
+ " interpreter_positions.append((0, interpreter_ids.size(1) - 1))\n",
664
+ "\n",
665
+ " print(f\"Generated continuation: {tokenizer.decode(continuation_tokens[0])}\")\n",
666
+ " else:\n",
667
+ " print(\"No code block found in the prompt.\")\n",
668
+ "\n",
669
+ " return cumulative_completion_ids, interpreter_positions\n",
670
+ "\n",
671
+ " # 4. Test with a prompt that has a complete code block\n",
672
+ " prompt = \"\"\"Let's calculate Fibonacci numbers in Python:\n",
673
+ "\n",
674
+ "<code>\n",
675
+ "def fibonacci(n):\n",
676
+ " a, b = 0, 1\n",
677
+ " result = []\n",
678
+ " for _ in range(n):\n",
679
+ " result.append(a)\n",
680
+ " a, b = b, a + b\n",
681
+ " return result\n",
682
+ "\n",
683
+ "print(fibonacci(5))\n",
684
+ "</code>\"\"\"\n",
685
+ "\n",
686
+ " # 5. Run our simplified test\n",
687
+ " try:\n",
688
+ " print(\"\\n--- Testing Core Functionality ---\")\n",
689
+ " completion, positions = simplified_generate_with_interpreter(model, tokenizer, prompt, device)\n",
690
+ "\n",
691
+ " print(\"\\n--- Final Results ---\")\n",
692
+ " print(\"Completion:\")\n",
693
+ " print(tokenizer.decode(completion[0]))\n",
694
+ " print(\"\\nInterpreter positions:\", positions)\n",
695
+ "\n",
696
+ " # 6. Now also test ReToolTrainer to verify core code execution functionality\n",
697
+ " print(\"\\n--- Testing ReToolTrainer with Direct Injection ---\")\n",
698
+ "\n",
699
+ " # Setup trainer\n",
700
+ " trainer = ReToolTrainer(\n",
701
+ " model=model,\n",
702
+ " processing_class=tokenizer,\n",
703
+ " args=transformers.TrainingArguments(\n",
704
+ " output_dir=\"./test_output\",\n",
705
+ " per_device_train_batch_size=1,\n",
706
+ " ),\n",
707
+ " train_dataset=None,\n",
708
+ " eval_dataset=None,\n",
709
+ " max_turns=3,\n",
710
+ " interpreter_id=[interpreter_start_id, interpreter_end_id],\n",
711
+ " code_id=[code_start_id, code_end_id],\n",
712
+ " eos_id=tokenizer.eos_token_id\n",
713
+ " )\n",
714
+ "\n",
715
+ " # Override the _execute_code method\n",
716
+ " def mock_execute_code(self, code_block):\n",
717
+ " print(f\"Mock executing code: {code_block}\")\n",
718
+ " return \"0 1 1 2 3\"\n",
719
+ "\n",
720
+ " original_execute_code = trainer._execute_code\n",
721
+ " trainer._execute_code = mock_execute_code.__get__(trainer, ReToolTrainer)\n",
722
+ "\n",
723
+ " # Create a sequence that has a prompt and ends with </code>\n",
724
+ " # This is to simulate that the model has generated a complete code block\n",
725
+ " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
726
+ " attention_mask = torch.ones_like(prompt_ids)\n",
727
+ "\n",
728
+ " # Directly inject a simulated generation with a code block\n",
729
+ " # Create a custom testing function\n",
730
+ " def test_execute_code_and_continue(self, prompt_ids, attention_mask):\n",
731
+ " \"\"\"Test just the code execution and continuation part\"\"\"\n",
732
+ " print(\"Testing code execution and continuation...\")\n",
733
+ " device = next(self.model.parameters()).device\n",
734
+ " prompt_ids = prompt_ids.to(device)\n",
735
+ " attention_mask = attention_mask.to(device)\n",
736
+ "\n",
737
+ " # Extract code from the prompt\n",
738
+ " full_text = self.processing_class.decode(prompt_ids[0])\n",
739
+ " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n",
740
+ "\n",
741
+ " if not code_match:\n",
742
+ " print(\"No code block found in the prompt!\")\n",
743
+ " return None, []\n",
744
+ "\n",
745
+ " code_block = code_match.group(1).strip()\n",
746
+ " print(f\"Executing code block: {code_block}\")\n",
747
+ "\n",
748
+ " # Execute the code\n",
749
+ " interpreter_text = self._execute_code(code_block)\n",
750
+ "\n",
751
+ " # Format and tokenize the interpreter output\n",
752
+ " formatted_feedback = f\"{self.processing_class.decode(self.interpreter_id[0])}{interpreter_text}{self.processing_class.decode(self.interpreter_id[1])}\"\n",
753
+ " interpreter_ids = self.processing_class(\n",
754
+ " formatted_feedback,\n",
755
+ " return_tensors=\"pt\",\n",
756
+ " add_special_tokens=False\n",
757
+ " ).input_ids.to(device)\n",
758
+ "\n",
759
+ " # Record position (relative to completion only)\n",
760
+ " interpreter_positions = [(0, interpreter_ids.size(1) - 1)]\n",
761
+ "\n",
762
+ " # Combine prompt with interpreter output for continuation\n",
763
+ " combined_ids = torch.cat([prompt_ids, interpreter_ids], dim=1)\n",
764
+ " combined_mask = torch.ones_like(combined_ids)\n",
765
+ "\n",
766
+ " # Generate continuation\n",
767
+ " continuation_outputs = self.model.generate(\n",
768
+ " input_ids=combined_ids,\n",
769
+ " attention_mask=combined_mask,\n",
770
+ " max_new_tokens=50,\n",
771
+ " do_sample=True,\n",
772
+ " temperature=0.7,\n",
773
+ " top_p=0.9,\n",
774
+ " pad_token_id=self.processing_class.pad_token_id,\n",
775
+ " eos_token_id=self.eos_id,\n",
776
+ " return_dict_in_generate=True,\n",
777
+ " cache_implementation= 'offloaded',\n",
778
+ " )\n",
779
+ "\n",
780
+ " # Extract only the newly generated continuation\n",
781
+ " continuation_tokens = continuation_outputs.sequences[:, combined_ids.size(1):]\n",
782
+ "\n",
783
+ " # Full completion is: interpreter output + continuation\n",
784
+ " completion = torch.cat([interpreter_ids, continuation_tokens], dim=1)\n",
785
+ "\n",
786
+ " return completion, interpreter_positions\n",
787
+ "\n",
788
+ " # Add the test method to the trainer\n",
789
+ " trainer.test_execute_code_and_continue = test_execute_code_and_continue.__get__(trainer, ReToolTrainer)\n",
790
+ "\n",
791
+ " # Run the test\n",
792
+ " completion, positions = trainer.test_execute_code_and_continue(prompt_ids, attention_mask)\n",
793
+ "\n",
794
+ " print(\"\\n--- Trainer Test Results ---\")\n",
795
+ " if completion is not None:\n",
796
+ " print(\"Completion:\")\n",
797
+ " print(tokenizer.decode(completion[0]))\n",
798
+ " print(\"\\nInterpreter positions:\", positions)\n",
799
+ "\n",
800
+ " except Exception as e:\n",
801
+ " import traceback\n",
802
+ " print(f\"Error during testing: {e}\")\n",
803
+ " traceback.print_exc()\n",
804
+ " finally:\n",
805
+ " # Restore original method if needed\n",
806
+ " if 'trainer' in locals() and 'original_execute_code' in locals():\n",
807
+ " trainer._execute_code = original_execute_code\n",
808
+ "\n",
809
+ "# Run the test\n",
810
+ "test_retool_core_functionality()"
811
+ ],
812
+ "metadata": {
813
+ "colab": {
814
+ "base_uri": "https://localhost:8080/"
815
+ },
816
+ "id": "MumcLxASaBkj",
817
+ "outputId": "21dd5d29-f402-4837-be86-08cee4b9d7a2"
818
+ },
819
+ "execution_count": 25,
820
+ "outputs": [
821
+ {
822
+ "output_type": "stream",
823
+ "name": "stdout",
824
+ "text": [
825
+ "Using device: cuda\n"
826
+ ]
827
+ },
828
+ {
829
+ "output_type": "stream",
830
+ "name": "stderr",
831
+ "text": [
832
+ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
833
+ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
834
+ ]
835
+ },
836
+ {
837
+ "output_type": "stream",
838
+ "name": "stdout",
839
+ "text": [
840
+ "Code tokens: 50257, 50258\n",
841
+ "Interpreter tokens: 50259, 50260\n",
842
+ "\n",
843
+ "--- Testing Core Functionality ---\n",
844
+ "Found code block: def fibonacci(n):\n",
845
+ " a, b = 0, 1\n",
846
+ " result = []\n",
847
+ " for _ in range(n):\n",
848
+ " result.append(a)\n",
849
+ " a, b = b, a + b\n",
850
+ " return result\n",
851
+ "\n",
852
+ "print(fibonacci(5))\n",
853
+ "Code execution result: 0 1 1 2 3\n",
854
+ "Generated continuation: 0 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48\n",
855
+ "\n",
856
+ "--- Final Results ---\n",
857
+ "Completion:\n",
858
+ "<interpreter>0 1 1 2 3</interpreter>0 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48\n",
859
+ "\n",
860
+ "Interpreter positions: [(0, 6)]\n",
861
+ "\n",
862
+ "--- Testing ReToolTrainer with Direct Injection ---\n",
863
+ "Testing code execution and continuation...\n",
864
+ "Executing code block: def fibonacci(n):\n",
865
+ " a, b = 0, 1\n",
866
+ " result = []\n",
867
+ " for _ in range(n):\n",
868
+ " result.append(a)\n",
869
+ " a, b = b, a + b\n",
870
+ " return result\n",
871
+ "\n",
872
+ "print(fibonacci(5))\n",
873
+ "Mock executing code: def fibonacci(n):\n",
874
+ " a, b = 0, 1\n",
875
+ " result = []\n",
876
+ " for _ in range(n):\n",
877
+ " result.append(a)\n",
878
+ " a, b = b, a + b\n",
879
+ " return result\n",
880
+ "\n",
881
+ "print(fibonacci(5))\n"
882
+ ]
883
+ },
884
+ {
885
+ "output_type": "stream",
886
+ "name": "stderr",
887
+ "text": [
888
+ "/tmp/ipython-input-20-2039368761.py:57: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `ReToolTrainer.__init__`. Use `processing_class` instead.\n",
889
+ " super().__init__(\n"
890
+ ]
891
+ },
892
+ {
893
+ "output_type": "stream",
894
+ "name": "stdout",
895
+ "text": [
896
+ "\n",
897
+ "--- Trainer Test Results ---\n",
898
+ "Completion:\n",
899
+ "<interpreter>0 1 1 2 3</interpreter>1 1 1 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n",
900
+ "\n",
901
+ "Interpreter positions: [(0, 6)]\n"
902
+ ]
903
+ }
904
+ ]
905
+ },
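Reviewer note: the core ReTool step exercised in the cell above is: find the last `<code>…</code>` block, execute it (mocked here), and feed the result back wrapped in `<interpreter>…</interpreter>` tokens. A hedged sketch of just that extraction and formatting step; the helper names are illustrative.

```python
import re
from typing import Optional

def extract_last_code_block(text: str) -> Optional[str]:
    """Return the contents of the last <code>...</code> block, if any."""
    blocks = re.findall(r"<code>(.*?)</code>", text, re.DOTALL)
    return blocks[-1].strip() if blocks else None

def format_interpreter_feedback(output: str) -> str:
    """Wrap interpreter output in the special tokens the tests add to the tokenizer."""
    return f"<interpreter>{output}</interpreter>"

text = "Let me check:\n<code>print(1 + 1)</code>"
code = extract_last_code_block(text)
if code is not None:
    feedback = format_interpreter_feedback("2")  # mocked execution result
    print(feedback)  # <interpreter>2</interpreter>
```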
906
+ {
907
+ "cell_type": "code",
908
+ "source": [],
909
+ "metadata": {
910
+ "id": "2KBJZTXOaCA3"
911
+ },
912
+ "execution_count": null,
913
+ "outputs": []
914
+ },
915
+ {
916
+ "cell_type": "code",
917
+ "source": [
918
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache\n",
919
+ "\n",
920
+ "def test_direct_kv_cache_usage():\n",
921
+ " # 1. Setup model, tokenizer, and device\n",
922
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
923
+ " import torch\n",
924
+ "\n",
925
+ " # Use a model that fits in memory\n",
926
+ " model_name = \"gpt2-medium\"\n",
927
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
928
+ "\n",
929
+ " # Check device\n",
930
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
931
+ " print(f\"Using device: {device}\")\n",
932
+ "\n",
933
+ " # Load model to device\n",
934
+ " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n",
935
+ "\n",
936
+ " # 2. Manual token-by-token generation with KV caching\n",
937
+ " def generate_with_manual_kv_cache(input_ids, num_tokens=20):\n",
938
+ " \"\"\"Generate tokens one by one with manual KV cache management\"\"\"\n",
939
+ " current_ids = input_ids.clone()\n",
940
+ " past_key_values = None\n",
941
+ "\n",
942
+ " generated_tokens = []\n",
943
+ "\n",
944
+ " for _ in range(num_tokens):\n",
945
+ " # Forward pass with past_key_values\n",
946
+ " with torch.no_grad():\n",
947
+ " outputs = model(\n",
948
+ " input_ids=current_ids if past_key_values is None else current_ids[:, -1:],\n",
949
+ " #past_key_values=past_key_values,\n",
950
+ " past_key_values= DynamicCache.from_legacy_cache(past_key_values),\n",
951
+ " use_cache=True\n",
952
+ " )\n",
953
+ "\n",
954
+ " # Get logits for the next token (last position)\n",
955
+ " next_token_logits = outputs.logits[:, -1, :]\n",
956
+ "\n",
957
+ " # Sample from the distribution\n",
958
+ " probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)\n",
959
+ " next_token = torch.multinomial(probs, num_samples=1)\n",
960
+ "\n",
961
+ " # Add to generated tokens\n",
962
+ " generated_tokens.append(next_token.item())\n",
963
+ "\n",
964
+ " # Update current_ids for next iteration\n",
965
+ " current_ids = torch.cat([current_ids, next_token], dim=1)\n",
966
+ "\n",
967
+ " # Update past_key_values\n",
968
+ " past_key_values = outputs.past_key_values\n",
969
+ " #print('after generation, past_key_values ', past_key_values)\n",
970
+ "\n",
971
+ " return generated_tokens, past_key_values\n",
972
+ "\n",
973
+ " # 3. Run multi-turn generation with manual KV cache\n",
974
+ " def run_manual_multi_turn_generation():\n",
975
+ " # Start with a prompt\n",
976
+ " prompt = \"Once upon a time, in a magical forest, there lived a\"\n",
977
+ "\n",
978
+ " # Tokenize the prompt\n",
979
+ " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
980
+ "\n",
981
+ " # Initialize tracking\n",
982
+ " full_text = prompt\n",
983
+ " current_ids = prompt_ids\n",
984
+ " past_kv = None\n",
985
+ "\n",
986
+ " # Generate in multiple turns\n",
987
+ " for turn_idx in range(3): # 3 turns\n",
988
+ " print(f\"\\n==== Turn {turn_idx + 1} ====\")\n",
989
+ " print(f\"Current input: {tokenizer.decode(current_ids[0])}\")\n",
990
+ " print(f\"KV cache present: {past_kv is not None}\")\n",
991
+ "\n",
992
+ " # Pause for inspection\n",
993
+ " input(\"Press Enter to generate next part...\")\n",
994
+ "\n",
995
+ " # Generate new tokens manually\n",
996
+ " new_token_ids, past_kv = generate_with_manual_kv_cache(current_ids, num_tokens=20)\n",
997
+ "\n",
998
+ " # Decode and display new tokens\n",
999
+ " new_text = tokenizer.decode(new_token_ids)\n",
1000
+ " print(f\"Generated: {new_text}\")\n",
1001
+ "\n",
1002
+ " # Accumulate\n",
1003
+ " full_text += new_text\n",
1004
+ "\n",
1005
+ " # Now inject a custom continuation\n",
1006
+ " custom_text = \" Suddenly, a rainbow appeared in the sky!\"\n",
1007
+ " custom_ids = tokenizer.encode(custom_text, return_tensors=\"pt\").to(device)\n",
1008
+ "\n",
1009
+ " print(f\"\\n==== Injecting custom text: {custom_text} ====\")\n",
1010
+ "\n",
1011
+ " # Update tracking\n",
1012
+ " full_text += custom_text\n",
1013
+ "\n",
1014
+ " # Prepare for next turn - start with the custom text\n",
1015
+ " current_ids = custom_ids\n",
1016
+ " # Keep past_kv from previous generation\n",
1017
+ "\n",
1018
+ " print(f\"Full text so far: {full_text}\")\n",
1019
+ " print(\"-\" * 50)\n",
1020
+ "\n",
1021
+ " print(\"\\n==== Final Story ====\")\n",
1022
+ " print(full_text)\n",
1023
+ "\n",
1024
+ " return full_text\n",
1025
+ "\n",
1026
+ " # 4. Run the test\n",
1027
+ " try:\n",
1028
+ " print(\"\\n=== Testing Manual KV Cache Usage ===\\n\")\n",
1029
+ "\n",
1030
+ " story = run_manual_multi_turn_generation()\n",
1031
+ "\n",
1032
+ " print(\"\\n=== Test Complete ===\")\n",
1033
+ "\n",
1034
+ " except Exception as e:\n",
1035
+ " import traceback\n",
1036
+ " print(f\"Error during testing: {e}\")\n",
1037
+ " traceback.print_exc()\n",
1038
+ "\n",
1039
+ "# Run the test\n",
1040
+ "test_direct_kv_cache_usage()"
1041
+ ],
1042
+ "metadata": {
1043
+ "colab": {
1044
+ "base_uri": "https://localhost:8080/"
1045
+ },
1046
+ "id": "p2_BiDS0IlPl",
1047
+ "outputId": "f62d0e49-8f27-4c9b-aeac-dbe1fe1def21"
1048
+ },
1049
+ "execution_count": 3,
1050
+ "outputs": [
1051
+ {
1052
+ "output_type": "stream",
1053
+ "name": "stdout",
1054
+ "text": [
1055
+ "Using device: cuda\n",
1056
+ "\n",
1057
+ "=== Testing Manual KV Cache Usage ===\n",
1058
+ "\n",
1059
+ "\n",
1060
+ "==== Turn 1 ====\n",
1061
+ "Current input: Once upon a time, in a magical forest, there lived a\n",
1062
+ "KV cache present: False\n",
1063
+ "Press Enter to generate next part...\n",
1064
+ "Generated: wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to\n",
1065
+ "\n",
1066
+ "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n",
1067
+ "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky!\n",
1068
+ "--------------------------------------------------\n",
1069
+ "\n",
1070
+ "==== Turn 2 ====\n",
1071
+ "Current input: Suddenly, a rainbow appeared in the sky!\n",
1072
+ "KV cache present: True\n",
1073
+ "Press Enter to generate next part...\n",
1074
+ "Generated: Although it was a little less than a hundred meters wide, it was enough to cover the entire sky\n",
1075
+ "\n",
1076
+ "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n",
1077
+ "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky!\n",
1078
+ "--------------------------------------------------\n",
1079
+ "\n",
1080
+ "==== Turn 3 ====\n",
1081
+ "Current input: Suddenly, a rainbow appeared in the sky!\n",
1082
+ "KV cache present: True\n",
1083
+ "Press Enter to generate next part...\n",
1084
+ "Generated: A rainbow!\n",
1085
+ "\n",
1086
+ "I thought about it, and after seeing how many people were in the sky\n",
1087
+ "\n",
1088
+ "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n",
1089
+ "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky! A rainbow!\n",
1090
+ "\n",
1091
+ "I thought about it, and after seeing how many people were in the sky Suddenly, a rainbow appeared in the sky!\n",
1092
+ "--------------------------------------------------\n",
1093
+ "\n",
1094
+ "==== Final Story ====\n",
1095
+ "Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky! A rainbow!\n",
1096
+ "\n",
1097
+ "I thought about it, and after seeing how many people were in the sky Suddenly, a rainbow appeared in the sky!\n",
1098
+ "\n",
1099
+ "=== Test Complete ===\n"
1100
+ ]
1101
+ }
1102
+ ]
1103
+ },
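Reviewer note: the cell above drives generation manually with a KV cache. The underlying pattern is: run the full prompt through the model once, then on every later step feed only the newest token together with `past_key_values`. A minimal sketch under the same gpt2-medium assumption, using greedy decoding for brevity.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("gpt2-medium").eval()

input_ids = tokenizer.encode("Once upon a time", return_tensors="pt")
past_key_values = None
generated = []

for _ in range(20):
    with torch.no_grad():
        outputs = model(
            # Only the last token is needed once a cache exists
            input_ids=input_ids if past_key_values is None else input_ids[:, -1:],
            past_key_values=past_key_values,
            use_cache=True,
        )
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated.append(next_token.item())
    input_ids = torch.cat([input_ids, next_token], dim=1)
    past_key_values = outputs.past_key_values  # reuse the cache on the next step

print(tokenizer.decode(generated))
```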
1104
+ {
1105
+ "cell_type": "code",
1106
+ "source": [
1107
+ "def test_retool_with_working_kv_cache():\n",
1108
+ " # 1. Setup model, tokenizer, and device\n",
1109
+ " from transformers import AutoModelForCausalLM, AutoTokenizer\n",
1110
+ " import torch\n",
1111
+ " import re\n",
1112
+ "\n",
1113
+ " # Use a model that fits in memory\n",
1114
+ " model_name = \"gpt2-medium\"\n",
1115
+ " tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
1116
+ "\n",
1117
+ " # Check device\n",
1118
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1119
+ " print(f\"Using device: {device}\")\n",
1120
+ "\n",
1121
+ " # Load model to device\n",
1122
+ " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n",
1123
+ "\n",
1124
+ " # 2. Add special tokens\n",
1125
+ " special_tokens = {\n",
1126
+ " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n",
1127
+ " }\n",
1128
+ " tokenizer.add_special_tokens(special_tokens)\n",
1129
+ " model.resize_token_embeddings(len(tokenizer))\n",
1130
+ "\n",
1131
+ " # Get token IDs\n",
1132
+ " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n",
1133
+ " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n",
1134
+ " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n",
1135
+ " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n",
1136
+ "\n",
1137
+ " print(f\"EOS token ID: {tokenizer.eos_token_id}\")\n",
1138
+ " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n",
1139
+ " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n",
1140
+ "\n",
1141
+ " # 3. Manual token generation with KV caching\n",
1142
+ " def generate_with_manual_kv_cache(input_ids, past_key_values=None, max_tokens=20, stop_ids=None):\n",
1143
+ " \"\"\"Generate tokens with KV cache until a stop token or max_tokens is reached\"\"\"\n",
1144
+ " if stop_ids is None:\n",
1145
+ " stop_ids = [tokenizer.eos_token_id]\n",
1146
+ "\n",
1147
+ " current_ids = input_ids.clone()\n",
1148
+ " generated_tokens = []\n",
1149
+ "\n",
1150
+ " for _ in range(max_tokens):\n",
1151
+ " # Forward pass with past_key_values\n",
1152
+ " with torch.no_grad():\n",
1153
+ " outputs = model(\n",
1154
+ " input_ids=current_ids if past_key_values is None else current_ids[:, -1:],\n",
1155
+ " past_key_values=past_key_values,\n",
1156
+ " use_cache=True\n",
1157
+ " )\n",
1158
+ "\n",
1159
+ " # Get logits for the next token\n",
1160
+ " next_token_logits = outputs.logits[:, -1, :]\n",
1161
+ "\n",
1162
+ " # Sample from the distribution\n",
1163
+ " probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)\n",
1164
+ " next_token = torch.multinomial(probs, num_samples=1)\n",
1165
+ "\n",
1166
+ " # Get the token ID\n",
1167
+ " token_id = next_token.item()\n",
1168
+ "\n",
1169
+ " # Add to generated tokens\n",
1170
+ " generated_tokens.append(token_id)\n",
1171
+ "\n",
1172
+ " # Update current_ids for next iteration\n",
1173
+ " current_ids = torch.cat([current_ids, next_token], dim=1)\n",
1174
+ "\n",
1175
+ " # Update past_key_values\n",
1176
+ " past_key_values = outputs.past_key_values\n",
1177
+ "\n",
1178
+ " # Check if we hit a stop token\n",
1179
+ " if token_id in stop_ids:\n",
1180
+ " break\n",
1181
+ "\n",
1182
+ " # Convert list of token IDs to tensor\n",
1183
+ " result_tensor = torch.tensor([generated_tokens], device=device)\n",
1184
+ " return result_tensor, past_key_values\n",
1185
+ "\n",
1186
+ " # 4. ReTool simulation with working KV cache\n",
1187
+ " def simulate_retool_with_working_kv_cache(prompt, max_turns=3):\n",
1188
+ " \"\"\"Simulate the ReTool process with working KV cache\"\"\"\n",
1189
+ " # Tokenize the prompt\n",
1190
+ " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
1191
+ "\n",
1192
+ " # Initialize tracking\n",
1193
+ " full_sequence = prompt_ids.clone()\n",
1194
+ " completion = torch.empty((1, 0), dtype=torch.long, device=device)\n",
1195
+ " interpreter_positions = []\n",
1196
+ "\n",
1197
+ " # Keep the KV cache from previous turns\n",
1198
+ " past_kv = None\n",
1199
+ "\n",
1200
+ " for turn_idx in range(max_turns):\n",
1201
+ " print(f\"\\n==== Turn {turn_idx + 1} ====\")\n",
1202
+ "\n",
1203
+ " # Determine what to generate from\n",
1204
+ " if turn_idx == 0:\n",
1205
+ " # First turn - generate from the prompt\n",
1206
+ " current_input = full_sequence\n",
1207
+ " print(f\"Generating from prompt: {tokenizer.decode(current_input[0])}\")\n",
1208
+ " else:\n",
1209
+ " # Later turns - might be generating from interpreter output\n",
1210
+ " current_input = full_sequence[:, -20:] if full_sequence.size(1) > 20 else full_sequence\n",
1211
+ " print(f\"Generating from: {tokenizer.decode(current_input[0])}\")\n",
1212
+ "\n",
1213
+ " # Generate with manual KV cache\n",
1214
+ " new_tokens, past_kv = generate_with_manual_kv_cache(\n",
1215
+ " current_input,\n",
1216
+ " past_key_values=past_kv,\n",
1217
+ " max_tokens=30,\n",
1218
+ " stop_ids=[tokenizer.eos_token_id, code_end_id]\n",
1219
+ " )\n",
1220
+ "\n",
1221
+ " # Decode and display\n",
1222
+ " new_text = tokenizer.decode(new_tokens[0])\n",
1223
+ " print(f\"Generated: {new_text}\")\n",
1224
+ "\n",
1225
+ " # Update tracking\n",
1226
+ " full_sequence = torch.cat([full_sequence, new_tokens], dim=1)\n",
1227
+ " completion = torch.cat([completion, new_tokens], dim=1)\n",
1228
+ "\n",
1229
+ " # Check for code blocks\n",
1230
+ " full_text = tokenizer.decode(full_sequence[0])\n",
1231
+ " code_blocks = re.findall(r'<code>(.*?)</code>', full_text, re.DOTALL)\n",
1232
+ "\n",
1233
+ " # Pause for inspection\n",
1234
+ " input(\"Press Enter to continue...\")\n",
1235
+ "\n",
1236
+ " if code_blocks and code_end_id in new_tokens[0]:\n",
1237
+ " print(\"\\n==== Found code block! ====\")\n",
1238
+ " # Get the last code block\n",
1239
+ " code_block = code_blocks[-1].strip()\n",
1240
+ " print(f\"Code block: {code_block}\")\n",
1241
+ "\n",
1242
+ " # Mock code execution\n",
1243
+ " print(\"\\n==== Executing code ====\")\n",
1244
+ " interpreter_output = \"0 1 1 2 3\"\n",
1245
+ " print(f\"Execution result: {interpreter_output}\")\n",
1246
+ "\n",
1247
+ " # Format interpreter feedback\n",
1248
+ " interpreter_text = f\"<interpreter>{interpreter_output}</interpreter>\"\n",
1249
+ " interpreter_ids = tokenizer.encode(\n",
1250
+ " interpreter_text,\n",
1251
+ " return_tensors=\"pt\",\n",
1252
+ " add_special_tokens=False\n",
1253
+ " ).to(device)\n",
1254
+ "\n",
1255
+ " # Record positions\n",
1256
+ " start_idx = completion.size(1)\n",
1257
+ " completion = torch.cat([completion, interpreter_ids], dim=1)\n",
1258
+ " end_idx = completion.size(1) - 1\n",
1259
+ " interpreter_positions.append((start_idx, end_idx))\n",
1260
+ "\n",
1261
+ " # Add to full sequence\n",
1262
+ " full_sequence = torch.cat([full_sequence, interpreter_ids], dim=1)\n",
1263
+ " print(f\"Added interpreter output: {interpreter_text}\")\n",
1264
+ "\n",
1265
+ " # We're still using the same past_kv for the next turn\n",
1266
+ " # The next input will be the interpreter output\n",
1267
+ " elif tokenizer.eos_token_id in new_tokens[0]:\n",
1268
+ " print(\"Found EOS token, ending generation\")\n",
1269
+ " break\n",
1270
+ "\n",
1271
+ " return completion, interpreter_positions\n",
1272
+ "\n",
1273
+ " # 5. Test with a prompt containing a code block\n",
1274
+ " prompt = \"\"\"Let me solve this problem with code:\n",
1275
+ "\n",
1276
+ "<code>\n",
1277
+ "def fibonacci(n):\n",
1278
+ " a, b = 0, 1\n",
1279
+ " result = []\n",
1280
+ " for _ in range(n):\n",
1281
+ " result.append(a)\n",
1282
+ " a, b = b, a + b\n",
1283
+ " return result\n",
1284
+ "\n",
1285
+ "print(fibonacci(5))\n",
1286
+ "</code>\"\"\"\n",
1287
+ "\n",
1288
+ " # 6. Run the test\n",
1289
+ " try:\n",
1290
+ " print(\"\\n=== Testing ReTool with Working KV Cache ===\\n\")\n",
1291
+ "\n",
1292
+ " completion, positions = simulate_retool_with_working_kv_cache(prompt)\n",
1293
+ "\n",
1294
+ " print(\"\\n=== Final Results ===\\n\")\n",
1295
+ " print(\"Generated completion:\")\n",
1296
+ " print(tokenizer.decode(completion[0]))\n",
1297
+ "\n",
1298
+ " print(\"\\nFull text:\")\n",
1299
+ " print(tokenizer.decode(torch.cat([tokenizer.encode(prompt, return_tensors=\"pt\")[0].to(device), completion[0]])))\n",
1300
+ "\n",
1301
+ " print(\"\\nInterpreter positions:\", positions)\n",
1302
+ "\n",
1303
+ " except Exception as e:\n",
1304
+ " import traceback\n",
1305
+ " print(f\"Error during testing: {e}\")\n",
1306
+ " traceback.print_exc()\n",
1307
+ "\n",
1308
+ "# Run the test\n",
1309
+ "test_retool_with_working_kv_cache()"
1310
+ ],
1311
+ "metadata": {
1312
+ "colab": {
1313
+ "base_uri": "https://localhost:8080/"
1314
+ },
1315
+ "id": "T6_ob3S4M5mn",
1316
+ "outputId": "e5f42a03-c49a-403f-d27b-0ae50ecd095e"
1317
+ },
1318
+ "execution_count": 4,
1319
+ "outputs": [
1320
+ {
1321
+ "output_type": "stream",
1322
+ "name": "stdout",
1323
+ "text": [
1324
+ "Using device: cuda\n"
1325
+ ]
1326
+ },
1327
+ {
1328
+ "output_type": "stream",
1329
+ "name": "stderr",
1330
+ "text": [
1331
+ "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
1332
+ ]
1333
+ },
1334
+ {
1335
+ "output_type": "stream",
1336
+ "name": "stdout",
1337
+ "text": [
1338
+ "EOS token ID: 50256\n",
1339
+ "Code tokens: 50257, 50258\n",
1340
+ "Interpreter tokens: 50259, 50260\n",
1341
+ "\n",
1342
+ "=== Testing ReTool with Working KV Cache ===\n",
1343
+ "\n",
1344
+ "\n",
1345
+ "==== Turn 1 ====\n",
1346
+ "Generating from prompt: Let me solve this problem with code:\n",
1347
+ "\n",
1348
+ "<code>\n",
1349
+ "def fibonacci(n):\n",
1350
+ " a, b = 0, 1\n",
1351
+ " result = []\n",
1352
+ " for _ in range(n):\n",
1353
+ " result.append(a)\n",
1354
+ " a, b = b, a + b\n",
1355
+ " return result\n",
1356
+ "\n",
1357
+ "print(fibonacci(5))\n",
1358
+ "</code>\n",
1359
+ "Generated: \n",
1360
+ "def fibonacci(n):\n",
1361
+ "\n",
1362
+ " a, b = 0, 1\n",
1363
+ "\n",
1364
+ " result = [0,\n",
1365
+ "Press Enter to continue...\n",
1366
+ "\n",
1367
+ "==== Turn 2 ====\n",
1368
+ "Generating from: a, b = 0, 1\n",
1369
+ "\n",
1370
+ " result = [0,\n",
1371
+ "Generated: 0, 0, 1]\n",
1372
+ "\n",
1373
+ " a, b = b, a + b\n",
1374
+ "\n",
1375
+ "\n",
1376
+ "ret = [0,\n",
1377
+ "Press Enter to continue...\n",
1378
+ "\n",
1379
+ "==== Turn 3 ====\n",
1380
+ "Generating from: a, b = b, a + b\n",
1381
+ "\n",
1382
+ "\n",
1383
+ "ret = [0,\n",
1384
+ "Generated: 1, 1, 1]\n",
1385
+ "\n",
1386
+ "for i,j in enumerate(n, fibonacci(n-1, 1-f\n",
1387
+ "Press Enter to continue...\n",
1388
+ "\n",
1389
+ "=== Final Results ===\n",
1390
+ "\n",
1391
+ "Generated completion:\n",
1392
+ "\n",
1393
+ "def fibonacci(n):\n",
1394
+ "\n",
1395
+ " a, b = 0, 1\n",
1396
+ "\n",
1397
+ " result = [0, 0, 0, 1]\n",
1398
+ "\n",
1399
+ " a, b = b, a + b\n",
1400
+ "\n",
1401
+ "\n",
1402
+ "ret = [0, 1, 1, 1]\n",
1403
+ "\n",
1404
+ "for i,j in enumerate(n, fibonacci(n-1, 1-f\n",
1405
+ "\n",
1406
+ "Full text:\n",
1407
+ "Let me solve this problem with code:\n",
1408
+ "\n",
1409
+ "<code>\n",
1410
+ "def fibonacci(n):\n",
1411
+ " a, b = 0, 1\n",
1412
+ " result = []\n",
1413
+ " for _ in range(n):\n",
1414
+ " result.append(a)\n",
1415
+ " a, b = b, a + b\n",
1416
+ " return result\n",
1417
+ "\n",
1418
+ "print(fibonacci(5))\n",
1419
+ "</code>\n",
1420
+ "def fibonacci(n):\n",
1421
+ "\n",
1422
+ " a, b = 0, 1\n",
1423
+ "\n",
1424
+ " result = [0, 0, 0, 1]\n",
1425
+ "\n",
1426
+ " a, b = b, a + b\n",
1427
+ "\n",
1428
+ "\n",
1429
+ "ret = [0, 1, 1, 1]\n",
1430
+ "\n",
1431
+ "for i,j in enumerate(n, fibonacci(n-1, 1-f\n",
1432
+ "\n",
1433
+ "Interpreter positions: []\n"
1434
+ ]
1435
+ }
1436
+ ]
1437
+ },
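Reviewer note: the custom loops above stop on either EOS or `</code>`. As a hedged alternative, recent `transformers` versions accept a list of stop ids in `generate()`'s `eos_token_id`, which approximates the same multi-stop behaviour without a hand-rolled loop; gpt2-medium and the `</code>` special token are assumed as in the cells above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.add_special_tokens({"additional_special_tokens": ["<code>", "</code>"]})
model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
model.resize_token_embeddings(len(tokenizer))

code_end_id = tokenizer.convert_tokens_to_ids("</code>")
prompt_ids = tokenizer.encode("Let me solve this with code:\n<code>\n", return_tensors="pt")

with torch.no_grad():
    out = model.generate(
        input_ids=prompt_ids,
        attention_mask=torch.ones_like(prompt_ids),
        max_new_tokens=30,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=[tokenizer.eos_token_id, code_end_id],  # stop on EOS *or* </code>
    )
new_tokens = out[:, prompt_ids.size(1):]
print(tokenizer.decode(new_tokens[0]))
```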
1438
+ {
1439
+ "cell_type": "code",
1440
+ "source": [],
1441
+ "metadata": {
1442
+ "id": "YFIXEa5fM5px"
1443
+ },
1444
+ "execution_count": null,
1445
+ "outputs": []
1446
+ },
1447
+ {
1448
+ "cell_type": "code",
1449
+ "source": [],
1450
+ "metadata": {
1451
+ "id": "FjaszXJOIlVz"
1452
+ },
1453
+ "execution_count": null,
1454
+ "outputs": []
1455
+ },
1456
+ {
1457
+ "cell_type": "code",
1458
+ "source": [],
1459
+ "metadata": {
1460
+ "id": "xgjX6_xZaCDQ"
1461
+ },
1462
+ "execution_count": null,
1463
+ "outputs": []
1464
+ },
1465
+ {
1466
+ "cell_type": "code",
1467
+ "source": [],
1468
+ "metadata": {
1469
+ "id": "iTGXE8lRaCF4"
1470
+ },
1471
+ "execution_count": null,
1472
+ "outputs": []
1473
+ },
1474
+ {
1475
+ "cell_type": "code",
1476
+ "source": [],
1477
+ "metadata": {
1478
+ "id": "oM5BSZHEaCIx"
1479
+ },
1480
+ "execution_count": null,
1481
+ "outputs": []
1482
+ },
1483
+ {
1484
+ "cell_type": "markdown",
1485
+ "metadata": {
1486
+ "id": "7d252539"
1487
+ },
1488
+ "source": [
1489
+ "**1. Clear CUDA Cache:**\n",
1490
+ "\n",
1491
+ "This is often the first thing to try when you get a CUDA OOM error."
1492
+ ]
1493
+ },
1494
+ {
1495
+ "cell_type": "code",
1496
+ "source": [],
1497
+ "metadata": {
1498
+ "id": "YhKSjnxiaBCb"
1499
+ },
1500
+ "execution_count": null,
1501
+ "outputs": []
1502
+ },
1503
+ {
1504
+ "cell_type": "code",
1505
+ "metadata": {
1506
+ "colab": {
1507
+ "base_uri": "https://localhost:8080/"
1508
+ },
1509
+ "id": "f793cb16",
1510
+ "outputId": "3b5b2b99-2e9b-44a2-88df-7293e51de014"
1511
+ },
1512
+ "source": [
1513
+ "import torch\n",
1514
+ "\n",
1515
+ "if torch.cuda.is_available():\n",
1516
+ " torch.cuda.empty_cache()\n",
1517
+ " print(\"CUDA cache cleared!\")\n",
1518
+ "else:\n",
1519
+ " print(\"CUDA not available, no cache to clear.\")"
1520
+ ],
1521
+ "execution_count": 18,
1522
+ "outputs": [
1523
+ {
1524
+ "output_type": "stream",
1525
+ "name": "stdout",
1526
+ "text": [
1527
+ "CUDA cache cleared!\n"
1528
+ ]
1529
+ }
1530
+ ]
1531
+ },
1532
+ {
1533
+ "cell_type": "markdown",
1534
+ "metadata": {
1535
+ "id": "d25e30fe"
1536
+ },
1537
+ "source": [
1538
+ "**2. Delete Large Variables and Run Garbage Collection:**\n",
1539
+ "\n",
1540
+ "Identify variables holding large objects (like models, tensors, dataframes) that you don't need anymore and delete them. Then explicitly run garbage collection."
1541
+ ]
1542
+ },
1543
+ {
1544
+ "cell_type": "code",
1545
+ "metadata": {
1546
+ "colab": {
1547
+ "base_uri": "https://localhost:8080/"
1548
+ },
1549
+ "id": "02474dce",
1550
+ "outputId": "80223089-31f7-485f-8490-aad00d97277a"
1551
+ },
1552
+ "source": [
1553
+ "# Example: if you have a large model or tensor named 'model' or 'data'\n",
1554
+ "# del model\n",
1555
+ "# del data\n",
1556
+ "\n",
1557
+ "import gc\n",
1558
+ "gc.collect()\n",
1559
+ "\n",
1560
+ "print(\"Garbage collection complete.\")"
1561
+ ],
1562
+ "execution_count": 19,
1563
+ "outputs": [
1564
+ {
1565
+ "output_type": "stream",
1566
+ "name": "stdout",
1567
+ "text": [
1568
+ "Garbage collection complete.\n"
1569
+ ]
1570
+ }
1571
+ ]
1572
+ },
1573
+ {
1574
+ "cell_type": "markdown",
1575
+ "metadata": {
1576
+ "id": "105cefce"
1577
+ },
1578
+ "source": [
1579
+ "**3. Restart Runtime:**\n",
1580
+ "\n",
1581
+ "If the above steps don't work, restarting the runtime is the most drastic but often most effective way to clear all memory. Go to the Colab menu: `Runtime` -> `Restart runtime`."
1582
+ ]
1583
+ }
1584
+ ]
1585
+ }
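Reviewer note: for the OOM-recovery markdown cells at the end of the notebook, a hedged sketch that combines garbage collection, cache clearing, and a memory report in one helper; the programmatic restart line is a common Colab workaround, not an official API.

```python
import gc
import torch

def report_and_clear_cuda():
    """Run garbage collection, release cached CUDA blocks, and report usage."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    else:
        print("CUDA not available, nothing to clear.")

# Typical usage after an OOM during generation:
# del model, outputs        # drop the references first
# report_and_clear_cuda()
# Last resort in Colab: force the runtime to restart programmatically
# import os; os.kill(os.getpid(), 9)
```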