# Claude Code Instructions - Token Attention Visualizer
## Project Overview
You are helping to build a Token Attention Visualizer - a web-based tool that visualizes attention weights in Large Language Models (LLMs) during text generation. The tool shows how input tokens influence the generation of output tokens through interactive visualizations.
## Core Functionality
1. Accept a text prompt and generate tokens using a Llama model
2. Extract and process attention matrices from the model
3. Create an interactive visualization showing token relationships
4. Allow users to click tokens to filter connections
5. Provide step-by-step navigation through the generation process
## Tech Stack
- **Backend**: FastAPI
- **Frontend**: Gradio (for easy Hugging Face Spaces deployment)
- **Visualization**: Plotly (interactive graphs)
- **ML**: Transformers, PyTorch
- **Models**: Llama models (1B-3B range)
## Project Structure
```
token-attention-viz/
├── app.py                 # Main Gradio application
├── api/
│   ├── __init__.py
│   ├── server.py          # FastAPI endpoints (optional)
│   └── models.py          # Pydantic models
├── core/
│   ├── __init__.py
│   ├── model_handler.py   # Model loading and generation
│   ├── attention.py       # Attention processing
│   └── cache.py           # Caching logic
├── visualization/
│   ├── __init__.py
│   ├── plotly_viz.py      # Plotly visualization
│   └── utils.py           # Token cleaning utilities
├── requirements.txt
└── config.py              # Configuration settings
```
## Implementation Guidelines
### Critical Code to Preserve from Original Implementation
1. **Model Loading Logic**:
- Device and dtype detection based on GPU capability
- Pad token handling for models without it
- Error handling for model loading
2. **Attention Extraction**:
- BOS token removal from visualization
- EOS token handling
- Attention matrix extraction with proper indexing
3. **Token Cleaning Function**:
```python
import re

def clean_label(token):
    """Map raw tokenizer output to a readable display label."""
    label = str(token)
    label = label.replace('Ġ', ' ')    # GPT-2/BPE space marker
    label = label.replace('▁', ' ')    # SentencePiece space marker
    label = label.replace('Ċ', '\\n')  # BPE newline marker, shown literally in labels
    label = label.replace('</s>', '[EOS]')
    label = label.replace('<unk>', '[UNK]')
    label = label.replace('<|begin_of_text|>', '[BOS]')
    label = label.replace('<|end_of_text|>', '[EOS]')
    label = re.sub(r'<0x[0-9A-Fa-f]{2}>', '', label)  # drop raw byte tokens
    return label.strip() if label.strip() else "[EMPTY]"
```
4. **Attention Processing with Separate Normalization**:
- Layer averaging across heads and layers
- Separate normalization for input and output attention
- Epsilon handling (1e-8) to avoid division by zero
5. **Interactive Features**:
- Token click handling to show specific connections
- Reset selection functionality
- Step-by-step navigation
- "All Connections" view
### Key Implementation Details
#### Model Handler (`core/model_handler.py`)
- Use `unsloth/Llama-3.2-1B-Instruct` as default model
- Implement proper device detection (CUDA if available)
- Use bfloat16 for GPUs with compute capability >= 8.0
- Generate with `output_attentions=True` and `return_dict_in_generate=True`
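A minimal sketch of that loading-and-generation path, using the standard `transformers` API (function names and structure are illustrative, not a fixed interface):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_name: str = "unsloth/Llama-3.2-1B-Instruct"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 only on GPUs with compute capability >= 8.0 (Ampere or newer)
    use_bf16 = device == "cuda" and torch.cuda.get_device_capability()[0] >= 8
    dtype = torch.bfloat16 if use_bf16 else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # models without a pad token reuse EOS

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
    return model.to(device).eval(), tokenizer

def generate_with_attention(model, tokenizer, prompt: str, max_new_tokens: int = 20):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        return model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_attentions=True,        # attach per-layer attention to the output
            return_dict_in_generate=True,  # so outputs.attentions is available
        )
```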
#### Attention Processing (`core/attention.py`)
- Extract attention for each generation step
- Average across all layers and heads
- Apply separate normalization (input and output attention normalized independently)
- Handle edge cases (first token has no output-to-output attention)
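A sketch of the per-step processing. It assumes the shapes `generate` returns: `outputs.attentions` is a tuple over generation steps, each step a tuple of per-layer tensors of shape `[batch, heads, q_len, k_len]` (with KV caching, `q_len` is the full prompt on the first step and 1 afterwards). The function name is illustrative:

```python
import torch

EPS = 1e-8  # avoids division by zero when a segment's attention sums to ~0

def process_step_attention(step_attentions, n_input_tokens: int):
    """Average one generation step over layers and heads, then normalize
    the input and output segments independently."""
    stacked = torch.stack(step_attentions)  # [layers, batch, heads, q_len, k_len]
    avg = stacked.mean(dim=(0, 2))[0]       # [q_len, k_len], layers+heads averaged, batch 0
    row = avg[-1].float().cpu()             # attention row for the newest query position

    input_attn = row[:n_input_tokens]
    output_attn = row[n_input_tokens:]      # empty on the first step
    input_attn = input_attn / (input_attn.sum() + EPS)
    if output_attn.numel() > 0:             # first token has no output-to-output attention
        output_attn = output_attn / (output_attn.sum() + EPS)
    return input_attn, output_attn
```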
#### Visualization (`visualization/plotly_viz.py`)
- **Layout**:
- Input tokens on left (x=0.1)
- Output tokens on right (x=0.9)
- Use linspace for y-coordinates
- **Connections**:
- Blue lines for input→output attention
- Orange curved lines for output→output attention
- Line thickness proportional to attention weight
- Only show connections above threshold
- **Interactivity**:
- Click on any token to filter connections
- Highlight selected token in yellow
- Show previously generated tokens in pink
- Current generating token in coral
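A sketch of the static layout (token columns plus thresholded input→output lines); the curved output→output traces and click handling would layer on top of this. Function and parameter names are illustrative:

```python
import numpy as np
import plotly.graph_objects as go

MIN_LINE_WIDTH, MAX_LINE_WIDTH = 0.5, 3.0

def build_figure(input_tokens, output_tokens, input_attn, threshold=0.05):
    """input_attn[j][i] = attention from output token j to input token i."""
    fig = go.Figure()
    y_in = np.linspace(0.9, 0.1, max(len(input_tokens), 2))
    y_out = np.linspace(0.9, 0.1, max(len(output_tokens), 2))

    # Blue lines: input -> output connections above the threshold
    for j, row in enumerate(input_attn):
        for i, w in enumerate(row):
            if w < threshold:
                continue
            width = MIN_LINE_WIDTH + float(w) * (MAX_LINE_WIDTH - MIN_LINE_WIDTH)
            fig.add_trace(go.Scatter(
                x=[0.1, 0.9], y=[float(y_in[i]), float(y_out[j])],
                mode="lines", line=dict(color="blue", width=width),
                hoverinfo="skip", showlegend=False))

    # Token columns: inputs at x=0.1, outputs at x=0.9
    fig.add_trace(go.Scatter(
        x=[0.1] * len(input_tokens), y=y_in[:len(input_tokens)],
        mode="markers+text", text=input_tokens,
        textposition="middle left", showlegend=False))
    fig.add_trace(go.Scatter(
        x=[0.9] * len(output_tokens), y=y_out[:len(output_tokens)],
        mode="markers+text", text=output_tokens,
        textposition="middle right", showlegend=False))

    fig.update_layout(width=1000, height=600,
                      xaxis=dict(visible=False, range=[0, 1]),
                      yaxis=dict(visible=False, range=[0, 1]))
    return fig
```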
#### Gradio Interface (`app.py`)
- **Input Controls**:
- Text area for prompt
- Slider for max tokens (1-50)
- Slider for attention threshold (0.0-0.2, step 0.001)
- **Visualization Controls**:
- Step slider for navigation
- Reset Selection button
- Show All Connections button
- **Display**:
- Generated text output
- Interactive Plotly graph
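A skeleton of the interface wiring using `gr.Blocks`; `run_generation` is a placeholder for the pipeline described above:

```python
import gradio as gr

def run_generation(prompt, max_tokens, threshold):
    # Placeholder: call the model handler and visualization pipeline here
    return "", None  # (generated text, plotly figure)

with gr.Blocks(title="Token Attention Visualizer") as demo:
    prompt = gr.Textbox(label="Prompt", value="The old wizard walked through the forest")
    max_tokens = gr.Slider(1, 50, value=20, step=1, label="Max tokens")
    threshold = gr.Slider(0.0, 0.2, value=0.05, step=0.001, label="Attention threshold")
    step = gr.Slider(0, 50, value=0, step=1, label="Generation step")
    with gr.Row():
        reset_btn = gr.Button("Reset Selection")
        all_btn = gr.Button("Show All Connections")
    output_text = gr.Textbox(label="Generated text")
    plot = gr.Plot(label="Attention graph")

    generate_btn = gr.Button("Generate", variant="primary")
    generate_btn.click(run_generation, [prompt, max_tokens, threshold], [output_text, plot])

if __name__ == "__main__":
    demo.launch()
```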
### Performance Optimizations
1. **Caching**:
- Cache generated attention matrices by prompt+max_tokens hash
- LRU cache with configurable size (default 10)
- Store processed attention, not raw tensors (a cache sketch follows this list)
2. **Lazy Updates**:
- Only update changed traces when stepping through
- Don't recreate entire plot on threshold change
- Use Plotly's batch_update for multiple changes
3. **Memory Management**:
- Clear raw attention tensors after processing
- Convert to CPU tensors for storage
- Use float32 instead of original dtype for visualization
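A sketch of the cache from item 1, using the standard library's `OrderedDict` for LRU eviction (class and method names are illustrative; see the key-normalization note under Common Issues for how keys might be built):

```python
from collections import OrderedDict

class AttentionCache:
    """LRU cache for processed (CPU, float32) attention data."""

    def __init__(self, max_size: int = 10):
        self.max_size = max_size
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)  # mark as most recently used
        return self._store[key]

    def put(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        if len(self._store) > self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry
```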
### Configuration (`config.py`)
```python
DEFAULT_MODEL = "unsloth/Llama-3.2-1B-Instruct"
DEFAULT_PROMPT = "The old wizard walked through the forest"
DEFAULT_MAX_TOKENS = 20
DEFAULT_THRESHOLD = 0.05
MIN_LINE_WIDTH = 0.5
MAX_LINE_WIDTH = 3.0
PLOT_WIDTH = 1000
PLOT_HEIGHT = 600
```
### Deployment Preparation
For Hugging Face Spaces deployment:
1. Create proper `requirements.txt` with pinned versions
2. Add `README.md` with Spaces metadata
3. Ensure model downloads work in Spaces environment
4. Set appropriate memory/GPU requirements
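For step 2, Spaces reads its metadata from a YAML block at the top of `README.md`; a typical header looks like the following (all values, including `sdk_version`, are illustrative placeholders to pin to what you actually test with):

```yaml
---
title: Token Attention Visualizer
emoji: 🔍
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: "4.44.0"
app_file: app.py
pinned: false
---
```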
## Testing Instructions
1. **Basic Functionality**:
- Test with default prompt
- Verify attention matrices are extracted correctly
- Check visualization renders properly
2. **Interactive Features**:
- Click on input tokens - should show only their connections to outputs
- Click on output tokens - should show incoming connections
- Reset button should clear selection
- Step slider should navigate through generation
3. **Edge Cases**:
- Empty prompt
- Single token generation
- Very long prompts (>100 tokens)
- High/low threshold values
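The token-cleaning edge cases lend themselves to a small pytest-style check; this sketch assumes `clean_label` is exposed from `visualization/utils.py` as in the project structure above:

```python
from visualization.utils import clean_label

def test_special_tokens_are_relabeled():
    assert clean_label("<|begin_of_text|>") == "[BOS]"
    assert clean_label("<|end_of_text|>") == "[EOS]"
    assert clean_label("</s>") == "[EOS]"
    assert clean_label("<unk>") == "[UNK]"

def test_whitespace_only_tokens_collapse_to_empty():
    # 'Ġ' maps to a bare space, which strips down to the [EMPTY] placeholder
    assert clean_label("Ġ") == "[EMPTY]"
    # Raw byte tokens like <0x0A> are stripped entirely
    assert clean_label("<0x0A>") == "[EMPTY]"
```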
## Development Workflow
1. Start by implementing the model handler and verify generation works
2. Add attention extraction and processing
3. Create basic visualization without interactivity
4. Add interactive features one by one
5. Implement caching
6. Create Gradio interface
7. Test and optimize performance
8. Prepare for deployment
## Important Notes
- Preserve the token cleaning logic exactly as it handles special tokens
- Keep the BOS token removal logic for cleaner visualization
- Maintain separate normalization (not joint) for attention weights
- Ensure CUDA memory is properly managed to avoid OOM errors
- Test with different model sizes based on available GPU memory
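For the CUDA-memory note, one common pattern is to release cached blocks once the raw attention tensors have been processed and copied to CPU; a sketch:

```python
import gc
import torch

def release_gpu_memory():
    # Call after raw attention tensors have been processed and moved to CPU
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached allocator blocks to the driver
```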
## Common Issues and Solutions
1. **CUDA OOM**: Reduce batch size or use smaller model
2. **Slow Generation**: Enable GPU, use smaller model, or implement streaming
3. **Visualization Lag**: Reduce number of traces, implement virtualization
4. **Cache Misses**: Normalize prompt formatting before hashing
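For issue 4, a minimal normalization-before-hashing sketch (how aggressively to normalize is a design choice):

```python
import hashlib

def cache_key(prompt: str, max_tokens: int) -> str:
    # Collapse runs of whitespace so trivially different prompts hit the same entry
    normalized = " ".join(prompt.strip().split())
    return hashlib.sha256(f"{normalized}|{max_tokens}".encode()).hexdigest()
```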
When implementing, prioritize functionality over optimization initially. Get the core visualization working first, then add caching and performance improvements.