---
license: apache-2.0
datasets:
- abhinavv3/edu_fineweb10B_sharded_50shards
language:
- en
pipeline_tag: text-generation
tags:
- text-generation
- transformer
---

# 🧠 GPT with Modified Memorizing Transformer

An extended GPT-style 118M-parameter model that integrates the key ideas from **"Memorizing Transformers" (Wu et al., 2022)** with practical enhancements such as Grouped Query Attention, KNN-based memory lookup, RoPE, and XL-style memory recurrence.

This model is designed for scalable training, long-context understanding, and efficient memory usage.

---

**Key Modifications from the Original Paper:**

1) Replaced the default positional encoding with Rotary Positional Embeddings (RoPE)
2) Altered the attention mechanism to use Grouped Query Attention
3) Customized the DataLoader to support sharded datasets and data parallelism
4) Implemented Mixed Precision Training along with Distributed Data Parallel (DDP) support
5) Tweaked several training and model hyperparameters for better adaptability

## 🔬 Key Features

- ✅ **Grouped Query Attention (GQA)** – Groups query heads to share key/value heads, saving memory and speeding up attention
- ✅ **KNN Memory** – A learnable mechanism to retrieve past activations via nearest-neighbor search
- ✅ **XL-style Attention** – Adds recurrence to the attention stack, improving long-sequence learning
- ✅ **Rotary Positional Encoding (RoPE)** – Replaces standard sin-cos encoding for better extrapolation
- ✅ **Memory Lifespan & Clearing** – Custom mechanisms to manage token memory duration
- ✅ **Sharded Dataset Loader** – Efficient `.npy`-based streaming for large datasets
- ✅ **Mixed Precision + DDP Training** – Scalable multi-GPU support using `torchrun` and `torch.autocast`

---

## 📁 Project Structure

```bash
MEM_TRANSFORMER/
├── configs/
│   └── config.json          # Model + training hyperparameters
│
├── data/
│   ├── edu_fineweb/         # Token-sharded training data
│   │   ├── train_000001.npy
│   │   ├── train_000002.npy
│   │   └── test_000001.npy
│   ├── hellaswag/
│   │   └── hellaswag_val.jsonl
│   └── fineweb.py           # Sharding logic with memory-aligned sequence control
│
├── model_core/
│   ├── __init__.py
│   ├── attention.py         # Grouped Query Attention, KNN & XL attention logic; RoPE implementation
│   ├── model.py             # Transformer model with memory and RoPE support
│   ├── dataloader.py        # Memory-aware DataLoader
│   └── training.py          # train_memgpt function
│
├── scripts/
│   ├── train.py             # Training script (DDP-compatible)
│   ├── evaluate.py          # Evaluation on benchmarks
│   └── generate.py          # Text generation from trained model
│
├── evaluation/
│   ├── __init__.py
│   ├── hellaswag.py         # HellaSwag data loader
│   └── val_hellaswag.py     # Evaluation logic with loss-based scoring
│
├── logs/
│   ├── log.txt              # Training logs
│   └── model_*.pt           # Checkpoints
│
├── .gitignore
├── README.md
└── requirements.txt
```

---

## ⚙️ Configuration

Edit `configs/config.json` to change model or training settings.

<details>
<summary>Example config</summary>

```json
{
  "model": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "n_kv_head": 4,
    "max_knn_memories": 81920
  },
  "training": {
    "max_steps": 19073,
    "log_dir": "log",
    "total_batch_size": 2048,
    "B": 64,
    "T": 1024,
    "max_lr": 0.0006,
    "min_lr": 0.00006,
    "warmup_steps": 715,
    "weight_decay": 0.1,
    "learning_rate": 0.00006,
    "learning_rate": 0.0006
  }
}
```
</details>
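
Since the config is plain JSON, it can be read with the standard library. A minimal loading sketch (the `GPTConfig` dataclass below is purely illustrative, not the repo's actual class; the field names mirror the example config above):

```python
import json
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # Field names mirror the "model" block of configs/config.json
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    n_kv_head: int = 4
    max_knn_memories: int = 81920

with open("configs/config.json") as f:
    cfg = json.load(f)

model_cfg = GPTConfig(**cfg["model"])   # model hyperparameters as a typed object
train_cfg = cfg["training"]             # training settings as a plain dict
print(model_cfg.n_head, train_cfg["max_steps"])
```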

## 🚀 Training

Single GPU: `python scripts/train.py`

Multi-GPU (DDP): `torchrun --nproc_per_node=NUM_GPUS scripts/train.py`
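
For orientation, this is roughly how the mixed-precision + DDP pieces fit together under `torchrun`. It is only a sketch with a stand-in linear model and fake data, not the repo's `scripts/train.py` or `train_memgpt`:

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets WORLD_SIZE / LOCAL_RANK; fall back to single-process defaults
    ddp = int(os.environ.get("WORLD_SIZE", "1")) > 1
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if ddp:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    # Stand-in model: a single linear layer in place of the memorizing transformer
    model = nn.Linear(768, 50304).to(device)
    if ddp:
        model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)

    for step in range(10):
        x = torch.randn(8, 768, device=device)            # fake batch of "hidden states"
        y = torch.randint(0, 50304, (8,), device=device)  # fake next-token targets
        # Mixed precision: run the forward/backward math in bfloat16 where safe
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    if ddp:
        dist.destroy_process_group()

if __name__ == "__main__":
    main()
```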

## 📊 Evaluation

Evaluate on the HellaSwag benchmark:

```bash
python scripts/evaluate.py
```

Requires:

- `data/hellaswag/hellaswag_val.jsonl`
- Model checkpoint(s) in `logs/`

Scoring is based on masked token loss across the multiple-choice completions.
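In other words: run each candidate completion through the model, average the cross-entropy only over the completion tokens (the context is masked out), and pick the lowest-loss ending. A sketch of that rule, not the repo's `evaluation/val_hellaswag.py`, assuming the model returns logits of shape `(batch, seq, vocab)`:

```python
import torch
import torch.nn.functional as F

def pick_ending(model, ctx_tokens, ending_tokens_list):
    """Return the index of the ending whose tokens get the lowest average loss.

    ctx_tokens: 1-D LongTensor holding the shared context.
    ending_tokens_list: list of 1-D LongTensors, one per candidate ending.
    """
    losses = []
    for ending in ending_tokens_list:
        tokens = torch.cat([ctx_tokens, ending]).unsqueeze(0)   # (1, T)
        logits = model(tokens)                                  # (1, T, vocab) assumed
        shift_logits = logits[:, :-1, :]                        # predict token t+1 from t
        shift_targets = tokens[:, 1:]
        per_token = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_targets.reshape(-1),
            reduction="none",
        ).reshape(shift_targets.shape)
        # Mask: average the loss over the ending tokens only, not the context
        mask = torch.zeros_like(shift_targets, dtype=torch.bool)
        mask[:, ctx_tokens.numel() - 1 :] = True
        losses.append(per_token[mask].mean())
    return int(torch.stack(losses).argmin())
```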

## 🧠 Attention Mechanism Deep Dive

<details>
<summary>Grouped Query Attention (GQA)</summary>

- `n_head` = total query heads
- `n_kv_head` = shared key/value heads
- Reduces compute overhead for large models by grouping query heads to reuse K/V

</details>
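
As a rough illustration of the grouping described above (shapes and the `repeat_interleave` approach are illustrative only; the actual implementation lives in `model_core/attention.py`):

```python
import torch
import torch.nn.functional as F

B, T, n_head, n_kv_head, head_dim = 2, 16, 12, 4, 64

q = torch.randn(B, n_head, T, head_dim)      # one query per head
k = torch.randn(B, n_kv_head, T, head_dim)   # fewer key heads...
v = torch.randn(B, n_kv_head, T, head_dim)   # ...and value heads

# Each group of n_head // n_kv_head query heads shares one K/V head
groups = n_head // n_kv_head
k = k.repeat_interleave(groups, dim=1)       # (B, n_head, T, head_dim)
v = v.repeat_interleave(groups, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                             # torch.Size([2, 12, 16, 64])
```

With `n_head = 12` and `n_kv_head = 4` (the values in the example config), every three query heads reuse the same K/V head, shrinking the K/V projections and cache by roughly 3×.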

<details>
<summary>KNN Memory Retrieval</summary>

- Maintains a memory of past key vectors (max: 81920 tokens)
- Fast KNN lookup with grouped projections
- Integrated into the attention flow in `model_core/attention.py`

</details>
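
The retrieval step can be pictured with a brute-force toy version like the one below. The class name and API are made up for illustration, and a real implementation would typically use an approximate-nearest-neighbor index rather than a dense matmul:

```python
import torch

class ToyKNNMemory:
    """Brute-force stand-in: stores past (key, value) pairs and returns the
    top-k closest keys per query by inner product."""

    def __init__(self, max_memories: int, dim: int):
        self.max_memories = max_memories
        self.keys = torch.empty(0, dim)
        self.values = torch.empty(0, dim)

    def add(self, k: torch.Tensor, v: torch.Tensor):
        # Append new memories, dropping the oldest beyond the cap
        self.keys = torch.cat([self.keys, k])[-self.max_memories:]
        self.values = torch.cat([self.values, v])[-self.max_memories:]

    def search(self, q: torch.Tensor, topk: int = 32):
        sims = q @ self.keys.T                       # (n_queries, n_memories)
        idx = sims.topk(min(topk, self.keys.size(0)), dim=-1).indices
        return self.keys[idx], self.values[idx]      # (n_queries, topk, dim)

mem = ToyKNNMemory(max_memories=81920, dim=64)
mem.add(torch.randn(128, 64), torch.randn(128, 64))
k_ret, v_ret = mem.search(torch.randn(16, 64), topk=8)
print(k_ret.shape)   # torch.Size([16, 8, 64])
```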

<details>
<summary>XL-style Recurrence</summary>

- Recurrence between attention blocks
- Memory cache updated at each step
- Custom clearing logic helps avoid stale activations

</details>
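
Conceptually, the recurrence amounts to caching a detached slice of each layer's activations and prepending it to the next segment, roughly as sketched here (function names and the `mem_len` value are illustrative, not taken from the repo):

```python
import torch

mem_len = 512
xl_memory = None   # cached hidden states from earlier segments (per layer in practice)

def attend_with_memory(hidden: torch.Tensor) -> torch.Tensor:
    """hidden: (B, T, C) activations for the current segment."""
    global xl_memory
    # Keys/values see [old memory | current segment]; no gradient flows into the memory
    context = hidden if xl_memory is None else torch.cat([xl_memory, hidden], dim=1)
    # ... attention over `context` would go here ...
    # Keep only the most recent mem_len positions, detached from the graph
    xl_memory = context[:, -mem_len:].detach()
    return hidden

def clear_memory():
    """Called e.g. at document boundaries so attention never sees stale activations."""
    global xl_memory
    xl_memory = None

x = torch.randn(2, 1024, 768)
attend_with_memory(x)
print(xl_memory.shape)   # torch.Size([2, 512, 768])
```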

<details>
<summary>Rotary Positional Encoding (RoPE)</summary>

- Replaces standard sinusoidal encoding
- Better generalization on long contexts
- Found in `model_core/attention.py`

</details>
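
For reference, a compact RoPE sketch in the common "rotate-half" formulation; the repo's version in `model_core/attention.py` may differ in convention and caching:

```python
import torch

def rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings to x of shape (..., seq_len, head_dim)."""
    seq_len, dim = x.shape[-2], x.shape[-1]
    half = dim // 2
    freqs = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()          # (seq_len, half)
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by a position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

q = torch.randn(2, 12, 16, 64)    # (batch, heads, seq, head_dim)
print(rope(q).shape)              # torch.Size([2, 12, 16, 64])
```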

## 🧩 Data Handling

- Training data is stored as sharded `.npy` files
- Matching stride/memory length logic
- DDP-compatible DataLoader
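
A stripped-down sketch of the sharding/striding pattern; the real loader is `model_core/dataloader.py` and additionally handles memory alignment, so the class and method names below are illustrative only:

```python
import numpy as np
import torch

class ToyShardedLoader:
    """Streams token shards saved as .npy files; each DDP rank reads its own
    offset and all ranks advance by world_size * B * T per step."""

    def __init__(self, shard_paths, B, T, rank=0, world_size=1):
        self.shards, self.B, self.T = list(shard_paths), B, T
        self.rank, self.world_size = rank, world_size
        self.shard_idx = 0
        self._load(self.shard_idx)

    def _load(self, idx):
        self.tokens = torch.from_numpy(np.load(self.shards[idx]).astype(np.int64))
        self.pos = self.rank * self.B * self.T   # each rank starts at its own slice

    def next_batch(self):
        B, T = self.B, self.T
        if self.pos + B * T + 1 > len(self.tokens):
            # Shard exhausted for this rank: rotate to the next shard
            self.shard_idx = (self.shard_idx + 1) % len(self.shards)
            self._load(self.shard_idx)
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = buf[:-1].view(B, T)                  # inputs
        y = buf[1:].view(B, T)                   # next-token targets
        self.pos += B * T * self.world_size      # skip over the other ranks' slices
        return x, y
```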

## 📦 Install Dependencies

```bash
pip install -r requirements.txt
```

Ensure that your PyTorch and CUDA versions match your local GPU setup.

## 📚 Reference

Wu et al., "Memorizing Transformers," ICLR 2022.
[Paper link](https://arxiv.org/abs/2203.08913)

## 💡 Future Work

- Add LoRA support
- Integrate with the Hugging Face `transformers` API
- Add benchmarking on other datasets (e.g. LAMBADA, PIQA)

Built with ❤️ by abhinavv3