{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ejiXlq27sck1", "outputId": "d2c846e5-97da-4533-d23f-1cb876d67069" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.52.4)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.18.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.30.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.32.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2.0.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.21.1)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.3)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (2025.3.2)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.30.0->transformers) (1.1.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.4.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2025.4.26)\n" ] } ], "source": [ "! pip install transformers" ] }, { "cell_type": "code", "source": [ "! 
pip install profiling-decorator" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3Sa_Bpi1srA9", "outputId": "6ad4ffd6-1058-4097-acb2-21978fe27ca0" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting profiling-decorator\n", " Downloading profiling_decorator-0.0.6-py3-none-any.whl.metadata (6.2 kB)\n", "Downloading profiling_decorator-0.0.6-py3-none-any.whl (9.2 kB)\n", "Installing collected packages: profiling-decorator\n", "Successfully installed profiling-decorator-0.0.6\n" ] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "9IRlvyF-J4Mp" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "eV3CXXy6J47P" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def test_updated_retool_implementation():\n", " # 1. Setup model, tokenizer, and device\n", " from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache\n", " import torch\n", " import transformers\n", " import re\n", "\n", " # Use a model that fits in memory\n", " model_name = \"gpt2-medium\"\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", " # Ensure padding token is set\n", " if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", " # Check device\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\"Using device: {device}\")\n", "\n", " # Load model to device\n", " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", "\n", " # 2. Add special tokens\n", " special_tokens = {\n", " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n", " }\n", " tokenizer.add_special_tokens(special_tokens)\n", " model.resize_token_embeddings(len(tokenizer))\n", "\n", " # Get token IDs\n", " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n", " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n", " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n", " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n", "\n", " print(f\"EOS token ID: {tokenizer.eos_token_id}\")\n", " print(f\"Pad token ID: {tokenizer.pad_token_id}\")\n", " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n", " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n", "\n", " # 3. 
Create a test version of your ReToolTrainer with custom generation\n", " class TestReToolTrainer:\n", " def __init__(self, model, tokenizer, device):\n", " self.model = model\n", " self.processing_class = tokenizer\n", " self.device = device\n", " self.temperature = 0.7\n", " self.top_p = 0.9\n", " self.top_k = 50\n", "\n", " # Ensure pad token is set\n", " if self.processing_class.pad_token is None:\n", " self.processing_class.pad_token = self.processing_class.eos_token\n", "\n", " def _execute_code(self, code_block):\n", " \"\"\"Mock code execution\"\"\"\n", " print(f\"\\n==== EXECUTING CODE ====\")\n", " print(f\"{code_block}\")\n", " print(f\"========================\\n\")\n", " return \"0 1 1 2 3\"\n", "\n", " def _custom_generate(self, input_ids, attention_mask=None, past_key_values=None, max_new_tokens=50, eos_token_ids=None):\n", " \"\"\"Custom generation function that avoids KV cache issues\"\"\"\n", " if attention_mask is None:\n", " attention_mask = torch.ones_like(input_ids)\n", "\n", " if eos_token_ids is None:\n", " eos_token_ids = [self.processing_class.eos_token_id]\n", "\n", " # Initialize\n", " current_ids = input_ids.clone()\n", " current_mask = attention_mask.clone()\n", " current_kv = past_key_values\n", "\n", " # Generate tokens in batches for efficiency\n", " all_tokens = []\n", " batch_size = 10 # Process this many tokens at once\n", "\n", " for start_idx in range(0, max_new_tokens, batch_size):\n", " # How many tokens to generate in this batch\n", " batch_tokens = min(batch_size, max_new_tokens - start_idx)\n", "\n", " # Accumulate new tokens\n", " new_tokens = []\n", "\n", " for _ in range(batch_tokens):\n", " # Forward pass with proper cache handling\n", " with torch.no_grad():\n", " outputs = self.model(\n", " input_ids=current_ids if current_kv is None else current_ids[:, -1:],\n", " attention_mask=current_mask if current_kv is None else current_mask[:, -1:],\n", " past_key_values=DynamicCache.from_legacy_cache(current_kv) if current_kv is not None else None,\n", " use_cache=True\n", " )\n", "\n", " # Sample next token\n", " next_token_logits = outputs.logits[:, -1, :] / self.temperature\n", " filtered_logits = self._filter_logits(next_token_logits)\n", " probs = torch.nn.functional.softmax(filtered_logits, dim=-1)\n", " next_token = torch.multinomial(probs, num_samples=1)\n", "\n", " # Add to accumulated tokens\n", " token_id = next_token.item()\n", " new_tokens.append(token_id)\n", "\n", " # Update for next iteration\n", " current_ids = torch.cat([current_ids, next_token], dim=1)\n", " token_mask = torch.ones((1, 1), device=current_mask.device, dtype=current_mask.dtype)\n", " current_mask = torch.cat([current_mask, token_mask], dim=1)\n", " current_kv = outputs.past_key_values\n", "\n", " # Check for stop tokens - include both EOS and code_end\n", " if token_id in eos_token_ids:\n", " break\n", "\n", " # Add batch tokens to overall result\n", " all_tokens.extend(new_tokens)\n", "\n", " # Check if we hit a stop token\n", " if len(new_tokens) < batch_tokens:\n", " break\n", "\n", " # Convert to tensor\n", " result = torch.tensor([all_tokens], device=input_ids.device)\n", " return result, current_kv\n", "\n", " def _filter_logits(self, logits):\n", " \"\"\"Apply top-k and top-p filtering\"\"\"\n", " if self.top_k > 0:\n", " top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)\n", " logits[0, :] = torch.full_like(logits[0, :], float('-inf'))\n", " logits[0, top_k_indices[0]] = top_k_logits[0]\n", "\n", " if self.top_p < 1.0:\n", " sorted_logits, 
sorted_indices = torch.sort(logits, descending=True, dim=-1)\n", " cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)\n", "\n", " # Remove tokens with cumulative probability above threshold\n", " sorted_indices_to_remove = cumulative_probs > self.top_p\n", " # Shift the indices to the right to keep the first token above threshold\n", " sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()\n", " sorted_indices_to_remove[:, 0] = 0\n", "\n", " # Scatter sorted tensors to original indexing\n", " indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)\n", " logits[indices_to_remove] = float('-inf')\n", "\n", " return logits\n", "\n", " def _retool_generate_with_interpreter(self, prompt_ids_batch, attention_mask_batch, eos_id, interpreter_id, code_id, max_turns=10):\n", " \"\"\"Your updated implementation with custom generation\"\"\"\n", " batch_size = prompt_ids_batch.size(0)\n", " batch_completion = []\n", " batch_interpreter_positions = []\n", "\n", " for i in range(batch_size):\n", " print(f\"Processing batch item {i+1}/{batch_size}\")\n", "\n", " # Initialize\n", " current_input_id = prompt_ids_batch[i:i+1]\n", " current_attention_mask = attention_mask_batch[i:i+1]\n", " current_kv = None\n", "\n", " # Track the completion part (no prompt)\n", " cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=prompt_ids_batch.device)\n", " interpreter_positions = []\n", "\n", " for turn_idx in range(max_turns):\n", " # Check if input is empty\n", " if current_input_id.size(1) == 0:\n", " print(f\"Turn {turn_idx + 1}: Input is empty, breaking loop\")\n", " break\n", "\n", " print(f\"\\n--- Turn {turn_idx + 1} ---\")\n", " print(f\"Current input: {self.processing_class.decode(current_input_id[0])}\")\n", " print(f\"KV cache present: {current_kv is not None}\")\n", "\n", " # Generate with custom function\n", " newly_generated_tokens, current_kv = self._custom_generate(\n", " input_ids=current_input_id,\n", " attention_mask=current_attention_mask,\n", " past_key_values=current_kv,\n", " max_new_tokens=30,\n", " eos_token_ids=[eos_id, code_id[1]]\n", " )\n", "\n", " # Display generated text\n", " print(f\"Generated: {self.processing_class.decode(newly_generated_tokens[0])}\")\n", "\n", " # Add to cumulative completion\n", " cumulative_completion_ids = torch.cat([cumulative_completion_ids, newly_generated_tokens], dim=1)\n", "\n", " # Check last token\n", " last_token_id = newly_generated_tokens[0, -1].item() if newly_generated_tokens.size(1) > 0 else None\n", " print(f\"Last token ID: {last_token_id}\")\n", "\n", " # Check for end conditions\n", " if last_token_id == eos_id:\n", " print(\"Found EOS token, ending generation\")\n", " break\n", "\n", " # Check for code end token\n", " if last_token_id == code_id[1]:\n", " print(\"Found </code> token, executing code\")\n", "\n", " # Extract code from the full text\n", " full_text = self.processing_class.decode(\n", " torch.cat([prompt_ids_batch[i], cumulative_completion_ids[0]], dim=0)\n", " )\n", " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n", "\n", " if code_match:\n", " code_block = code_match.group(1).strip()\n", "\n", " # Execute code\n", " interpreter_text = self._execute_code(code_block)\n", "\n", " # Format and add interpreter output\n", " formatted_feedback = f\"{self.processing_class.decode(interpreter_id[0])}{interpreter_text}{self.processing_class.decode(interpreter_id[1])}\"\n", " interpreter_ids = 
self.processing_class(\n", " formatted_feedback,\n", " return_tensors=\"pt\",\n", " add_special_tokens=False\n", " ).input_ids.to(prompt_ids_batch.device)\n", "\n", " # Record positions\n", " interpreter_start_idx = cumulative_completion_ids.size(1)\n", " cumulative_completion_ids = torch.cat([cumulative_completion_ids, interpreter_ids], dim=1)\n", " interpreter_end_idx = cumulative_completion_ids.size(1) - 1\n", " interpreter_positions.append((interpreter_start_idx, interpreter_end_idx))\n", "\n", " print(f\"Added interpreter output: {formatted_feedback}\")\n", "\n", " # Set up for next turn\n", " current_input_id = interpreter_ids\n", " current_attention_mask = torch.ones_like(current_input_id)\n", " # Keep current_kv from previous generation\n", " else:\n", " print(\"No code block found despite </code> token\")\n", " break\n", " else:\n", " # Continue with the newly generated tokens\n", " current_input_id = newly_generated_tokens\n", " current_attention_mask = torch.ones_like(current_input_id)\n", "\n", " # Add to batch results\n", " batch_completion.append(cumulative_completion_ids.squeeze(0))\n", " batch_interpreter_positions.append(interpreter_positions)\n", "\n", " # Pad sequences\n", " if len(batch_completion) > 0:\n", " # Ensure padding_value is a valid integer\n", " padding_value = self.processing_class.pad_token_id\n", " if padding_value is None:\n", " padding_value = 0 # Use 0 as a default if pad_token_id is None\n", "\n", " padded_sequences = torch.nn.utils.rnn.pad_sequence(\n", " batch_completion,\n", " batch_first=True,\n", " padding_value=padding_value\n", " )\n", " else:\n", " padded_sequences = torch.empty((0, 0), dtype=torch.long, device=prompt_ids_batch.device)\n", "\n", " return padded_sequences, batch_interpreter_positions\n", "\n", " # 4. Create test instance\n", " tester = TestReToolTrainer(model, tokenizer, device)\n", "\n", " # 5. Create a test prompt with a complete code block\n", " prompt = \"\"\"Let me solve this problem with code:\n", "\n", "<code>\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "</code>\"\"\"\n", "\n", " # 6. 
Run the test\n", " try:\n", " print(\"\\n=== Testing Updated ReTool Implementation ===\\n\")\n", "\n", " # Encode the prompt\n", " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", " attention_mask = torch.ones_like(prompt_ids)\n", "\n", " # Run the generation\n", " completions, positions = tester._retool_generate_with_interpreter(\n", " prompt_ids_batch=prompt_ids,\n", " attention_mask_batch=attention_mask,\n", " eos_id=tokenizer.eos_token_id,\n", " interpreter_id=[interpreter_start_id, interpreter_end_id],\n", " code_id=[code_start_id, code_end_id],\n", " max_turns=3\n", " )\n", "\n", " # Display results\n", " print(\"\\n=== Final Results ===\\n\")\n", " print(\"Generated completion:\")\n", " print(tokenizer.decode(completions[0]))\n", "\n", " print(\"\\nFull text:\")\n", " print(tokenizer.decode(torch.cat([prompt_ids[0], completions[0]])))\n", "\n", " print(\"\\nInterpreter positions:\", positions)\n", "\n", " except Exception as e:\n", " import traceback\n", " print(f\"Error during testing: {e}\")\n", " traceback.print_exc()\n", "\n", "# Run the test\n", "test_updated_retool_implementation()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4_E6Eo7EHC_8", "outputId": "35b195d9-b0ff-4ddf-c216-fba1c83f40e2" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cpu\n", "EOS token ID: 50256\n", "Pad token ID: 50256\n", "Code tokens: 50257, 50258\n", "Interpreter tokens: 50259, 50260\n", "\n", "=== Testing Updated ReTool Implementation ===\n", "\n", "Processing batch item 1/1\n", "\n", "--- Turn 1 ---\n", "Current input: Let me solve this problem with code:\n", "\n", "\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "KV cache present: False\n", "Generated: \n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = []\n", "\n", "Last token ID: 198\n", "\n", "--- Turn 2 ---\n", "Current input: \n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = []\n", "\n", "KV cache present: True\n", "Generated: \n", " a, b = b, a + b\n", "\n", " return result\n", "\n", "print(fibon\n", "Last token ID: 261\n", "\n", "--- Turn 3 ---\n", "Current input: \n", " a, b = b, a + b\n", "\n", " return result\n", "\n", "print(fibon\n", "KV cache present: True\n", "Generated: acci(5))\n", "\n", "So the first two methods are all the same, the last one is a little different, and the second one is the\n", "Last token ID: 262\n", "\n", "=== Final Results ===\n", "\n", "Generated completion:\n", "\n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = []\n", "\n", " a, b = b, a + b\n", "\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "So the first two methods are all the same, the last one is a little different, and the second one is the\n", "\n", "Full text:\n", "Let me solve this problem with code:\n", "\n", "\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = []\n", "\n", " a, b = b, a + b\n", "\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "So the first two methods are all the same, the last one is a little different, and the second one is the\n", "\n", "Interpreter positions: [[]]\n" 
] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "Z0EHHkP3J7Ox" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "def test_retool_core_functionality():\n", " # 1. Create minimal model and tokenizer\n", " from transformers import AutoModelForCausalLM, AutoTokenizer\n", " import torch\n", " import transformers\n", " import re\n", "\n", " # Use a small model for testing\n", " model_name = \"gpt2-medium\"\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", " # Check if CUDA is available\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\"Using device: {device}\")\n", "\n", " # Load model directly to the selected device\n", " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", "\n", " # 2. Add special tokens to the tokenizer\n", " special_tokens = {\n", " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n", " }\n", " tokenizer.add_special_tokens(special_tokens)\n", " model.resize_token_embeddings(len(tokenizer))\n", "\n", " # Get token IDs for special tokens\n", " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n", " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n", " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n", " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n", "\n", " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n", " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n", "\n", " # 3. Create a simplified implementation of _retool_generate_with_interpreter\n", " def simplified_generate_with_interpreter(model, tokenizer, prompt_text, device):\n", " \"\"\"Simplified version focusing just on the core functionality\"\"\"\n", " # Step 1: Tokenize the prompt\n", " prompt_ids = tokenizer.encode(prompt_text, return_tensors=\"pt\").to(device)\n", "\n", " # Initialize tracking variables\n", " cumulative_completion_ids = torch.empty((1, 0), dtype=torch.long, device=device)\n", " interpreter_positions = []\n", "\n", " # Step 2: Extract a code block and execute it\n", " full_text = prompt_text\n", " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n", "\n", " if code_match:\n", " code_block = code_match.group(1).strip()\n", " print(f\"Found code block: {code_block}\")\n", "\n", " # Mock code execution\n", " interpreter_output = \"0 1 1 2 3\"\n", " print(f\"Code execution result: {interpreter_output}\")\n", "\n", " # Format interpreter feedback\n", " interpreter_text = f\"<interpreter>{interpreter_output}</interpreter>\"\n", " interpreter_ids = tokenizer.encode(\n", " interpreter_text,\n", " return_tensors=\"pt\",\n", " add_special_tokens=False\n", " ).to(device)\n", "\n", " # Step 3: Generate a continuation after the interpreter output\n", " # First, create a sequence with prompt + interpreter output\n", " combined_input = prompt_text + interpreter_text\n", " combined_ids = tokenizer.encode(combined_input, return_tensors=\"pt\").to(device)\n", "\n", " # Generate continuation\n", " with torch.no_grad():\n", " continuation_outputs = model.generate(\n", " input_ids=combined_ids,\n", " max_new_tokens=50,\n", " do_sample=True,\n", " temperature=0.7,\n", " top_p=0.9,\n", " pad_token_id=tokenizer.pad_token_id,\n", " eos_token_id=tokenizer.eos_token_id,\n", " return_dict_in_generate=True,\n", " cache_implementation='offloaded',\n", " )\n", "\n", " # Extract only the newly generated tokens\n", " continuation_tokens = continuation_outputs.sequences[:, combined_ids.size(1):]\n", "\n", " # Combine everything for the final result\n", " # The 
completion consists of: interpreter output + continuation\n", " cumulative_completion_ids = torch.cat([interpreter_ids, continuation_tokens], dim=1)\n", "\n", " # Record the interpreter position\n", " interpreter_positions.append((0, interpreter_ids.size(1) - 1))\n", "\n", " print(f\"Generated continuation: {tokenizer.decode(continuation_tokens[0])}\")\n", " else:\n", " print(\"No code block found in the prompt.\")\n", "\n", " return cumulative_completion_ids, interpreter_positions\n", "\n", " # 4. Test with a prompt that has a complete code block\n", " prompt = \"\"\"Let's calculate Fibonacci numbers in Python:\n", "\n", "<code>\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "</code>\"\"\"\n", "\n", " # 5. Run our simplified test\n", " try:\n", " print(\"\\n--- Testing Core Functionality ---\")\n", " completion, positions = simplified_generate_with_interpreter(model, tokenizer, prompt, device)\n", "\n", " print(\"\\n--- Final Results ---\")\n", " print(\"Completion:\")\n", " print(tokenizer.decode(completion[0]))\n", " print(\"\\nInterpreter positions:\", positions)\n", "\n", " # 6. Now also test ReToolTrainer to verify core code execution functionality\n", " print(\"\\n--- Testing ReToolTrainer with Direct Injection ---\")\n", "\n", " # Setup trainer\n", " trainer = ReToolTrainer(\n", " model=model,\n", " processing_class=tokenizer,\n", " args=transformers.TrainingArguments(\n", " output_dir=\"./test_output\",\n", " per_device_train_batch_size=1,\n", " ),\n", " train_dataset=None,\n", " eval_dataset=None,\n", " max_turns=3,\n", " interpreter_id=[interpreter_start_id, interpreter_end_id],\n", " code_id=[code_start_id, code_end_id],\n", " eos_id=tokenizer.eos_token_id\n", " )\n", "\n", " # Override the _execute_code method\n", " def mock_execute_code(self, code_block):\n", " print(f\"Mock executing code: {code_block}\")\n", " return \"0 1 1 2 3\"\n", "\n", " original_execute_code = trainer._execute_code\n", " trainer._execute_code = mock_execute_code.__get__(trainer, ReToolTrainer)\n", "\n", " # Create a sequence that has a prompt and ends with </code>\n", " # This is to simulate that the model has generated a complete code block\n", " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", " attention_mask = torch.ones_like(prompt_ids)\n", "\n", " # Directly inject a simulated generation with a code block\n", " # Create a custom testing function\n", " def test_execute_code_and_continue(self, prompt_ids, attention_mask):\n", " \"\"\"Test just the code execution and continuation part\"\"\"\n", " print(\"Testing code execution and continuation...\")\n", " device = next(self.model.parameters()).device\n", " prompt_ids = prompt_ids.to(device)\n", " attention_mask = attention_mask.to(device)\n", "\n", " # Extract code from the prompt\n", " full_text = self.processing_class.decode(prompt_ids[0])\n", " code_match = re.search(r'<code>(.*?)</code>', full_text, re.DOTALL)\n", "\n", " if not code_match:\n", " print(\"No code block found in the prompt!\")\n", " return None, []\n", "\n", " code_block = code_match.group(1).strip()\n", " print(f\"Executing code block: {code_block}\")\n", "\n", " # Execute the code\n", " interpreter_text = self._execute_code(code_block)\n", "\n", " # Format and tokenize the interpreter output\n", " formatted_feedback = 
f\"{self.processing_class.decode(self.interpreter_id[0])}{interpreter_text}{self.processing_class.decode(self.interpreter_id[1])}\"\n", " interpreter_ids = self.processing_class(\n", " formatted_feedback,\n", " return_tensors=\"pt\",\n", " add_special_tokens=False\n", " ).input_ids.to(device)\n", "\n", " # Record position (relative to completion only)\n", " interpreter_positions = [(0, interpreter_ids.size(1) - 1)]\n", "\n", " # Combine prompt with interpreter output for continuation\n", " combined_ids = torch.cat([prompt_ids, interpreter_ids], dim=1)\n", " combined_mask = torch.ones_like(combined_ids)\n", "\n", " # Generate continuation\n", " continuation_outputs = self.model.generate(\n", " input_ids=combined_ids,\n", " attention_mask=combined_mask,\n", " max_new_tokens=50,\n", " do_sample=True,\n", " temperature=0.7,\n", " top_p=0.9,\n", " pad_token_id=self.processing_class.pad_token_id,\n", " eos_token_id=self.eos_id,\n", " return_dict_in_generate=True,\n", " cache_implementation= 'offloaded',\n", " )\n", "\n", " # Extract only the newly generated continuation\n", " continuation_tokens = continuation_outputs.sequences[:, combined_ids.size(1):]\n", "\n", " # Full completion is: interpreter output + continuation\n", " completion = torch.cat([interpreter_ids, continuation_tokens], dim=1)\n", "\n", " return completion, interpreter_positions\n", "\n", " # Add the test method to the trainer\n", " trainer.test_execute_code_and_continue = test_execute_code_and_continue.__get__(trainer, ReToolTrainer)\n", "\n", " # Run the test\n", " completion, positions = trainer.test_execute_code_and_continue(prompt_ids, attention_mask)\n", "\n", " print(\"\\n--- Trainer Test Results ---\")\n", " if completion is not None:\n", " print(\"Completion:\")\n", " print(tokenizer.decode(completion[0]))\n", " print(\"\\nInterpreter positions:\", positions)\n", "\n", " except Exception as e:\n", " import traceback\n", " print(f\"Error during testing: {e}\")\n", " traceback.print_exc()\n", " finally:\n", " # Restore original method if needed\n", " if 'trainer' in locals() and 'original_execute_code' in locals():\n", " trainer._execute_code = original_execute_code\n", "\n", "# Run the test\n", "test_retool_core_functionality()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "MumcLxASaBkj", "outputId": "21dd5d29-f402-4837-be86-08cee4b9d7a2" }, "execution_count": 25, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. 
Please pass your input's `attention_mask` to obtain reliable results.\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Code tokens: 50257, 50258\n", "Interpreter tokens: 50259, 50260\n", "\n", "--- Testing Core Functionality ---\n", "Found code block: def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "Code execution result: 0 1 1 2 3\n", "Generated continuation: 0 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48\n", "\n", "--- Final Results ---\n", "Completion:\n", "0 1 1 2 30 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48\n", "\n", "Interpreter positions: [(0, 6)]\n", "\n", "--- Testing ReToolTrainer with Direct Injection ---\n", "Testing code execution and continuation...\n", "Executing code block: def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "Mock executing code: def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/tmp/ipython-input-20-2039368761.py:57: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `ReToolTrainer.__init__`. Use `processing_class` instead.\n", " super().__init__(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", "--- Trainer Test Results ---\n", "Completion:\n", "0 1 1 2 31 1 1 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3\n", "\n", "Interpreter positions: [(0, 6)]\n" ] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "2KBJZTXOaCA3" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache\n", "\n", "def test_direct_kv_cache_usage():\n", " # 1. Setup model, tokenizer, and device\n", " from transformers import AutoModelForCausalLM, AutoTokenizer\n", " import torch\n", "\n", " # Use a model that fits in memory\n", " model_name = \"gpt2-medium\"\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", " # Check device\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\"Using device: {device}\")\n", "\n", " # Load model to device\n", " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", "\n", " # 2. 
Manual token-by-token generation with KV caching\n", " def generate_with_manual_kv_cache(input_ids, num_tokens=20):\n", " \"\"\"Generate tokens one by one with manual KV cache management\"\"\"\n", " current_ids = input_ids.clone()\n", " past_key_values = None\n", "\n", " generated_tokens = []\n", "\n", " for _ in range(num_tokens):\n", " # Forward pass with past_key_values\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=current_ids if past_key_values is None else current_ids[:, -1:],\n", " #past_key_values=past_key_values,\n", " past_key_values= DynamicCache.from_legacy_cache(past_key_values),\n", " use_cache=True\n", " )\n", "\n", " # Get logits for the next token (last position)\n", " next_token_logits = outputs.logits[:, -1, :]\n", "\n", " # Sample from the distribution\n", " probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)\n", " next_token = torch.multinomial(probs, num_samples=1)\n", "\n", " # Add to generated tokens\n", " generated_tokens.append(next_token.item())\n", "\n", " # Update current_ids for next iteration\n", " current_ids = torch.cat([current_ids, next_token], dim=1)\n", "\n", " # Update past_key_values\n", " past_key_values = outputs.past_key_values\n", " #print('after generation, past_key_values ', past_key_values)\n", "\n", " return generated_tokens, past_key_values\n", "\n", " # 3. Run multi-turn generation with manual KV cache\n", " def run_manual_multi_turn_generation():\n", " # Start with a prompt\n", " prompt = \"Once upon a time, in a magical forest, there lived a\"\n", "\n", " # Tokenize the prompt\n", " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", "\n", " # Initialize tracking\n", " full_text = prompt\n", " current_ids = prompt_ids\n", " past_kv = None\n", "\n", " # Generate in multiple turns\n", " for turn_idx in range(3): # 3 turns\n", " print(f\"\\n==== Turn {turn_idx + 1} ====\")\n", " print(f\"Current input: {tokenizer.decode(current_ids[0])}\")\n", " print(f\"KV cache present: {past_kv is not None}\")\n", "\n", " # Pause for inspection\n", " input(\"Press Enter to generate next part...\")\n", "\n", " # Generate new tokens manually\n", " new_token_ids, past_kv = generate_with_manual_kv_cache(current_ids, num_tokens=20)\n", "\n", " # Decode and display new tokens\n", " new_text = tokenizer.decode(new_token_ids)\n", " print(f\"Generated: {new_text}\")\n", "\n", " # Accumulate\n", " full_text += new_text\n", "\n", " # Now inject a custom continuation\n", " custom_text = \" Suddenly, a rainbow appeared in the sky!\"\n", " custom_ids = tokenizer.encode(custom_text, return_tensors=\"pt\").to(device)\n", "\n", " print(f\"\\n==== Injecting custom text: {custom_text} ====\")\n", "\n", " # Update tracking\n", " full_text += custom_text\n", "\n", " # Prepare for next turn - start with the custom text\n", " current_ids = custom_ids\n", " # Keep past_kv from previous generation\n", "\n", " print(f\"Full text so far: {full_text}\")\n", " print(\"-\" * 50)\n", "\n", " print(\"\\n==== Final Story ====\")\n", " print(full_text)\n", "\n", " return full_text\n", "\n", " # 4. 
Run the test\n", " try:\n", " print(\"\\n=== Testing Manual KV Cache Usage ===\\n\")\n", "\n", " story = run_manual_multi_turn_generation()\n", "\n", " print(\"\\n=== Test Complete ===\")\n", "\n", " except Exception as e:\n", " import traceback\n", " print(f\"Error during testing: {e}\")\n", " traceback.print_exc()\n", "\n", "# Run the test\n", "test_direct_kv_cache_usage()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p2_BiDS0IlPl", "outputId": "f62d0e49-8f27-4c9b-aeac-dbe1fe1def21" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n", "\n", "=== Testing Manual KV Cache Usage ===\n", "\n", "\n", "==== Turn 1 ====\n", "Current input: Once upon a time, in a magical forest, there lived a\n", "KV cache present: False\n", "Press Enter to generate next part...\n", "Generated: wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to\n", "\n", "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n", "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky!\n", "--------------------------------------------------\n", "\n", "==== Turn 2 ====\n", "Current input: Suddenly, a rainbow appeared in the sky!\n", "KV cache present: True\n", "Press Enter to generate next part...\n", "Generated: Although it was a little less than a hundred meters wide, it was enough to cover the entire sky\n", "\n", "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n", "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky!\n", "--------------------------------------------------\n", "\n", "==== Turn 3 ====\n", "Current input: Suddenly, a rainbow appeared in the sky!\n", "KV cache present: True\n", "Press Enter to generate next part...\n", "Generated: A rainbow!\n", "\n", "I thought about it, and after seeing how many people were in the sky\n", "\n", "==== Injecting custom text: Suddenly, a rainbow appeared in the sky! ====\n", "Full text so far: Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky! A rainbow!\n", "\n", "I thought about it, and after seeing how many people were in the sky Suddenly, a rainbow appeared in the sky!\n", "--------------------------------------------------\n", "\n", "==== Final Story ====\n", "Once upon a time, in a magical forest, there lived a wizard and a witch. One day, they were attacked by a flying beast and the wizard fled to Suddenly, a rainbow appeared in the sky! Although it was a little less than a hundred meters wide, it was enough to cover the entire sky Suddenly, a rainbow appeared in the sky! 
A rainbow!\n", "\n", "I thought about it, and after seeing how many people were in the sky Suddenly, a rainbow appeared in the sky!\n", "\n", "=== Test Complete ===\n" ] } ] }, { "cell_type": "code", "source": [ "def test_retool_with_working_kv_cache():\n", " # 1. Setup model, tokenizer, and device\n", " from transformers import AutoModelForCausalLM, AutoTokenizer\n", " import torch\n", " import re\n", "\n", " # Use a model that fits in memory\n", " model_name = \"gpt2-medium\"\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", " # Check device\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " print(f\"Using device: {device}\")\n", "\n", " # Load model to device\n", " model = AutoModelForCausalLM.from_pretrained(model_name).to(device)\n", "\n", " # 2. Add special tokens\n", " special_tokens = {\n", " 'additional_special_tokens': ['<code>', '</code>', '<interpreter>', '</interpreter>']\n", " }\n", " tokenizer.add_special_tokens(special_tokens)\n", " model.resize_token_embeddings(len(tokenizer))\n", "\n", " # Get token IDs\n", " code_start_id = tokenizer.convert_tokens_to_ids('<code>')\n", " code_end_id = tokenizer.convert_tokens_to_ids('</code>')\n", " interpreter_start_id = tokenizer.convert_tokens_to_ids('<interpreter>')\n", " interpreter_end_id = tokenizer.convert_tokens_to_ids('</interpreter>')\n", "\n", " print(f\"EOS token ID: {tokenizer.eos_token_id}\")\n", " print(f\"Code tokens: {code_start_id}, {code_end_id}\")\n", " print(f\"Interpreter tokens: {interpreter_start_id}, {interpreter_end_id}\")\n", "\n", " # 3. Manual token generation with KV caching\n", " def generate_with_manual_kv_cache(input_ids, past_key_values=None, max_tokens=20, stop_ids=None):\n", " \"\"\"Generate tokens with KV cache until a stop token or max_tokens is reached\"\"\"\n", " if stop_ids is None:\n", " stop_ids = [tokenizer.eos_token_id]\n", "\n", " current_ids = input_ids.clone()\n", " generated_tokens = []\n", "\n", " for _ in range(max_tokens):\n", " # Forward pass with past_key_values\n", " with torch.no_grad():\n", " outputs = model(\n", " input_ids=current_ids if past_key_values is None else current_ids[:, -1:],\n", " past_key_values=past_key_values,\n", " use_cache=True\n", " )\n", "\n", " # Get logits for the next token\n", " next_token_logits = outputs.logits[:, -1, :]\n", "\n", " # Sample from the distribution\n", " probs = torch.nn.functional.softmax(next_token_logits / 0.7, dim=-1)\n", " next_token = torch.multinomial(probs, num_samples=1)\n", "\n", " # Get the token ID\n", " token_id = next_token.item()\n", "\n", " # Add to generated tokens\n", " generated_tokens.append(token_id)\n", "\n", " # Update current_ids for next iteration\n", " current_ids = torch.cat([current_ids, next_token], dim=1)\n", "\n", " # Update past_key_values\n", " past_key_values = outputs.past_key_values\n", "\n", " # Check if we hit a stop token\n", " if token_id in stop_ids:\n", " break\n", "\n", " # Convert list of token IDs to tensor\n", " result_tensor = torch.tensor([generated_tokens], device=device)\n", " return result_tensor, past_key_values\n", "\n", " # 4. 
ReTool simulation with working KV cache\n", " def simulate_retool_with_working_kv_cache(prompt, max_turns=3):\n", " \"\"\"Simulate the ReTool process with working KV cache\"\"\"\n", " # Tokenize the prompt\n", " prompt_ids = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n", "\n", " # Initialize tracking\n", " full_sequence = prompt_ids.clone()\n", " completion = torch.empty((1, 0), dtype=torch.long, device=device)\n", " interpreter_positions = []\n", "\n", " # Keep the KV cache from previous turns\n", " past_kv = None\n", "\n", " for turn_idx in range(max_turns):\n", " print(f\"\\n==== Turn {turn_idx + 1} ====\")\n", "\n", " # Determine what to generate from\n", " if turn_idx == 0:\n", " # First turn - generate from the prompt\n", " current_input = full_sequence\n", " print(f\"Generating from prompt: {tokenizer.decode(current_input[0])}\")\n", " else:\n", " # Later turns - might be generating from interpreter output\n", " current_input = full_sequence[:, -20:] if full_sequence.size(1) > 20 else full_sequence\n", " print(f\"Generating from: {tokenizer.decode(current_input[0])}\")\n", "\n", " # Generate with manual KV cache\n", " new_tokens, past_kv = generate_with_manual_kv_cache(\n", " current_input,\n", " past_key_values=past_kv,\n", " max_tokens=30,\n", " stop_ids=[tokenizer.eos_token_id, code_end_id]\n", " )\n", "\n", " # Decode and display\n", " new_text = tokenizer.decode(new_tokens[0])\n", " print(f\"Generated: {new_text}\")\n", "\n", " # Update tracking\n", " full_sequence = torch.cat([full_sequence, new_tokens], dim=1)\n", " completion = torch.cat([completion, new_tokens], dim=1)\n", "\n", " # Check for code blocks\n", " full_text = tokenizer.decode(full_sequence[0])\n", " code_blocks = re.findall(r'<code>(.*?)</code>', full_text, re.DOTALL)\n", "\n", " # Pause for inspection\n", " input(\"Press Enter to continue...\")\n", "\n", " if code_blocks and code_end_id in new_tokens[0]:\n", " print(\"\\n==== Found code block! ====\")\n", " # Get the last code block\n", " code_block = code_blocks[-1].strip()\n", " print(f\"Code block: {code_block}\")\n", "\n", " # Mock code execution\n", " print(\"\\n==== Executing code ====\")\n", " interpreter_output = \"0 1 1 2 3\"\n", " print(f\"Execution result: {interpreter_output}\")\n", "\n", " # Format interpreter feedback\n", " interpreter_text = f\"<interpreter>{interpreter_output}</interpreter>\"\n", " interpreter_ids = tokenizer.encode(\n", " interpreter_text,\n", " return_tensors=\"pt\",\n", " add_special_tokens=False\n", " ).to(device)\n", "\n", " # Record positions\n", " start_idx = completion.size(1)\n", " completion = torch.cat([completion, interpreter_ids], dim=1)\n", " end_idx = completion.size(1) - 1\n", " interpreter_positions.append((start_idx, end_idx))\n", "\n", " # Add to full sequence\n", " full_sequence = torch.cat([full_sequence, interpreter_ids], dim=1)\n", " print(f\"Added interpreter output: {interpreter_text}\")\n", "\n", " # We're still using the same past_kv for the next turn\n", " # The next input will be the interpreter output\n", " elif tokenizer.eos_token_id in new_tokens[0]:\n", " print(\"Found EOS token, ending generation\")\n", " break\n", "\n", " return completion, interpreter_positions\n", "\n", " # 5. Test with a prompt containing a code block\n", " prompt = \"\"\"Let me solve this problem with code:\n", "\n", "<code>\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "</code>\"\"\"\n", "\n", " # 6. 
Run the test\n", " try:\n", " print(\"\\n=== Testing ReTool with Working KV Cache ===\\n\")\n", "\n", " completion, positions = simulate_retool_with_working_kv_cache(prompt)\n", "\n", " print(\"\\n=== Final Results ===\\n\")\n", " print(\"Generated completion:\")\n", " print(tokenizer.decode(completion[0]))\n", "\n", " print(\"\\nFull text:\")\n", " print(tokenizer.decode(torch.cat([tokenizer.encode(prompt, return_tensors=\"pt\")[0].to(device), completion[0]])))\n", "\n", " print(\"\\nInterpreter positions:\", positions)\n", "\n", " except Exception as e:\n", " import traceback\n", " print(f\"Error during testing: {e}\")\n", " traceback.print_exc()\n", "\n", "# Run the test\n", "test_retool_with_working_kv_cache()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "T6_ob3S4M5mn", "outputId": "e5f42a03-c49a-403f-d27b-0ae50ecd095e" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "EOS token ID: 50256\n", "Code tokens: 50257, 50258\n", "Interpreter tokens: 50259, 50260\n", "\n", "=== Testing ReTool with Working KV Cache ===\n", "\n", "\n", "==== Turn 1 ====\n", "Generating from prompt: Let me solve this problem with code:\n", "\n", "\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "Generated: \n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = [0,\n", "Press Enter to continue...\n", "\n", "==== Turn 2 ====\n", "Generating from: a, b = 0, 1\n", "\n", " result = [0,\n", "Generated: 0, 0, 1]\n", "\n", " a, b = b, a + b\n", "\n", "\n", "ret = [0,\n", "Press Enter to continue...\n", "\n", "==== Turn 3 ====\n", "Generating from: a, b = b, a + b\n", "\n", "\n", "ret = [0,\n", "Generated: 1, 1, 1]\n", "\n", "for i,j in enumerate(n, fibonacci(n-1, 1-f\n", "Press Enter to continue...\n", "\n", "=== Final Results ===\n", "\n", "Generated completion:\n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = [0, 0, 0, 1]\n", "\n", " a, b = b, a + b\n", "\n", "\n", "ret = [0, 1, 1, 1]\n", "\n", "for i,j in enumerate(n, fibonacci(n-1, 1-f\n", "\n", "Full text:\n", "Let me solve this problem with code:\n", "\n", "\n", "def fibonacci(n):\n", " a, b = 0, 1\n", " result = []\n", " for _ in range(n):\n", " result.append(a)\n", " a, b = b, a + b\n", " return result\n", "\n", "print(fibonacci(5))\n", "\n", "def fibonacci(n):\n", "\n", " a, b = 0, 1\n", "\n", " result = [0, 0, 0, 1]\n", "\n", " a, b = b, a + b\n", "\n", "\n", "ret = [0, 1, 1, 1]\n", "\n", "for i,j in enumerate(n, fibonacci(n-1, 1-f\n", "\n", "Interpreter positions: []\n" ] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "YFIXEa5fM5px" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "FjaszXJOIlVz" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "xgjX6_xZaCDQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "iTGXE8lRaCF4" }, "execution_count": 
null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "oM5BSZHEaCIx" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "7d252539" }, "source": [ "**1. Clear CUDA Cache:**\n", "\n", "This is often the first thing to try when you get a CUDA OOM error." ] }, { "cell_type": "code", "source": [], "metadata": { "id": "YhKSjnxiaBCb" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f793cb16", "outputId": "3b5b2b99-2e9b-44a2-88df-7293e51de014" }, "source": [ "import torch\n", "\n", "if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " print(\"CUDA cache cleared!\")\n", "else:\n", " print(\"CUDA not available, no cache to clear.\")" ], "execution_count": 18, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "CUDA cache cleared!\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "d25e30fe" }, "source": [ "**2. Delete Large Variables and Run Garbage Collection:**\n", "\n", "Identify variables holding large objects (like models, tensors, dataframes) that you don't need anymore and delete them. Then explicitly run garbage collection." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "02474dce", "outputId": "80223089-31f7-485f-8490-aad00d97277a" }, "source": [ "# Example: if you have a large model or tensor named 'model' or 'data'\n", "# del model\n", "# del data\n", "\n", "import gc\n", "gc.collect()\n", "\n", "print(\"Garbage collection complete.\")" ], "execution_count": 19, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Garbage collection complete.\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "105cefce" }, "source": [ "**3. Restart Runtime:**\n", "\n", "If the above steps don't work, restarting the runtime is the most drastic but often most effective way to clear all memory. Go to the Colab menu: `Runtime` -> `Restart runtime`." ] } ] }