# Memorizing Transformer with Grouped Query Attention
An extended GPT-2-style large language model (LLM) that implements core components from the research paper **Memorizing Transformers (Wu et al., 2022)**.
This project incorporates Grouped Query Attention (GQA), KNN-based memory retrieval, XL-style attention, and Rotary Positional Encoding (RoPE).
The training pipeline supports distributed training, data parallelism, and sharded dataset streaming.
---
## Key Features
- **Grouped Query Attention**: Efficient query representation by grouping multiple attention heads for shared K/V access
- **KNN-based Memory**: Long-term memory retrieval from past activations using a learned KNN mechanism
- **XL-style Attention**: Recurrence-based memory layers adapted for KNN and grouped attention logic
- **Rotary Positional Encoding**: More efficient and generalizable positional representation than vanilla sin-cos encoding
- **Sharded Dataset Loader**: Handles large datasets with sharding and supports data parallelism via PyTorch DDP
- **Custom Memory Clearing Logic**: Memory reset and lifespan mechanisms tuned for stability and performance during training
- **Mixed Precision & DDP Training**: Efficient large-scale training using `torch.autocast` and `torchrun`
---
## Project Structure
```bash
MEM_TRANSFORMER/
├── configs/
│   └── config.json            # Model + training hyperparameters
│
├── data/
│   ├── edu_fineweb/           # Token-sharded training data
│   │   ├── train_000001.npy
│   │   ├── train_000002.npy
│   │   └── test_000001.npy
│   ├── hellaswag/
│   │   └── hellaswag_val.jsonl
│   └── fineweb.py             # Sharding logic with memory-aligned sequence control
│
├── model_core/
│   ├── __init__.py
│   ├── attention.py           # Grouped Query Attention, KNN & XL attention logic; Rotary Positional Encoding implementation
│   ├── model.py               # Transformer model with memory and RoPE support
│   ├── dataloader.py          # Memory-aware DataLoader
│   └── training.py            # train_memgpt function
│
├── scripts/
│   ├── train.py               # Training script (DDP-compatible)
│   ├── evaluate.py            # Evaluation on benchmarks
│   └── generate.py            # Text generation from trained model
│
├── evaluation/
│   ├── __init__.py
│   ├── hellaswag.py           # HellaSwag data loader
│   └── val_hellaswag.py       # Evaluation logic with loss-based scoring
│
├── logs/
│   ├── log.txt                # Training logs
│   └── model_*.pt             # Checkpoints
│
├── .gitignore
├── README.md
└── requirements.txt
```
## Configuration
Edit the config file at `configs/config.json` to adjust model and training hyperparameters:
```json
{
  "model": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "n_kv_head": 4,
    "max_knn_memories": 81920
  },
  "training": {
    "max_steps": 19073,
    "log_dir": "log",
    "total_batch_size": 2048,
    "B": 64,
    "T": 1024,
    "max_lr": 0.0006,
    "min_lr": 0.00006,
    "warmup_steps": 715,
    "weight_decay": 0.1,
    "learning_rate": 0.0006
  }
}
```
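For illustration, the config can be read with the standard `json` module. This is a hedged sketch of how the values above might be consumed; the actual loading code in `scripts/train.py` may differ, and the variable names below are not from the repo.

```python
import json

# Load model and training hyperparameters (keys taken from the sample config above).
with open("configs/config.json") as f:
    config = json.load(f)

model_cfg = config["model"]      # e.g. n_layer=12, n_head=12, n_kv_head=4
train_cfg = config["training"]   # e.g. B=64, T=1024, max_lr=6e-4

print(f"Embedding dim: {model_cfg['n_embd']}, KV heads: {model_cfg['n_kv_head']}")
```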
## Training
### Single-GPU Training
```bash
python scripts/train.py
```
### Distributed Training (Multi-GPU with DDP)
```bash
torchrun --nproc_per_node=NUM_GPUS scripts/train.py
```
Replace `NUM_GPUS` with the number of GPUs available.
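For reference, here is a minimal sketch of the standard `torchrun` + DDP setup, assuming the script follows the usual pattern of reading the rank from environment variables; `scripts/train.py` may organize this differently, and the `torch.nn.Linear` module below is only a stand-in for the actual model.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each process it launches.
is_ddp = int(os.environ.get("RANK", -1)) != -1
if is_ddp:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
else:
    local_rank = 0
    device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(768, 768).to(device)   # stand-in for the memorizing transformer
if is_ddp:
    model = DDP(model, device_ids=[local_rank])

# ... training loop with torch.autocast for mixed precision ...

if is_ddp:
    dist.destroy_process_group()
```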
## Evaluation
Evaluate on the HellaSwag benchmark:
```bash
python scripts/evaluate.py
```
Make sure the file `data/hellaswag/hellaswag_val.jsonl` is present.
The evaluation uses completion scoring based on masked loss comparisons across candidate endings.
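A simplified sketch of that loss-based completion scoring follows. It assumes a model that returns logits of shape `(1, T, vocab)`; the function name, shapes, and masking details are illustrative, not the repo's exact API in `evaluation/val_hellaswag.py`.

```python
import torch
import torch.nn.functional as F

def score_completion(model, tokens, mask):
    """Average cross-entropy over the candidate-ending region only.

    tokens: (1, T) long tensor of context + candidate ending
    mask:   (1, T) with 1 where the token belongs to the candidate ending
    """
    with torch.no_grad():
        logits = model(tokens)                     # (1, T, vocab)
    shift_logits = logits[:, :-1, :]
    shift_tokens = tokens[:, 1:]
    shift_mask = mask[:, 1:].float()
    losses = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_tokens.reshape(-1),
        reduction="none",
    ).view(shift_tokens.shape)
    # Lower masked loss => more likely ending; the argmin across the candidate
    # endings is compared against the gold label.
    return (losses * shift_mask).sum() / shift_mask.sum()
```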
## Attention Mechanism Notes
### Grouped Query Attention (GQA)
- `n_head` query heads
- `n_kv_head` shared key/value heads
- Query heads are grouped and averaged before memory lookup
- More efficient than per-head K/V for large models
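One common way to realize GQA is to expand the shared K/V heads across each query group, as in the hedged sketch below; the grouped-query averaging used for memory lookup in this repo is not shown, and head counts are taken from the sample config.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_head=12, n_kv_head=4):
    """q: (B, n_head, T, d); k, v: (B, n_kv_head, T, d).

    Each group of n_head // n_kv_head query heads shares one K/V head.
    """
    group_size = n_head // n_kv_head
    k = k.repeat_interleave(group_size, dim=1)   # (B, n_head, T, d)
    v = v.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Example shapes matching the config (n_embd=768 -> head dim 64).
B, T, d = 2, 16, 64
q = torch.randn(B, 12, T, d)
k = torch.randn(B, 4, T, d)
v = torch.randn(B, 4, T, d)
out = grouped_query_attention(q, k, v)           # (B, 12, T, d)
```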
### KNN Memory Integration
- A maximum memory buffer of 81,920 tokens (`max_knn_memories`)
- Query vectors are projected and grouped for efficient KNN search
- Careful shape transformations ensure fast grouped matching
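A hedged sketch of the basic KNN-memory idea: past key/value pairs are kept in a flat buffer capped at `max_knn_memories`, and the top-k nearest keys are retrieved per query. The repo's actual memory module is more involved (grouped queries, lifespans, per-layer buffers); the class below is a toy illustration, not its API.

```python
import torch

class KNNMemory:
    """Toy flat memory: stores past keys/values and retrieves top-k per query."""

    def __init__(self, max_memories=81920, dim=64):
        self.keys = torch.empty(0, dim)
        self.values = torch.empty(0, dim)
        self.max_memories = max_memories

    def add(self, k, v):
        # k, v: (N, dim); keep only the most recent max_memories entries.
        self.keys = torch.cat([self.keys, k])[-self.max_memories:]
        self.values = torch.cat([self.values, v])[-self.max_memories:]

    def search(self, q, topk=32):
        # q: (M, dim) -> returns (M, topk, dim) retrieved keys and values.
        sims = q @ self.keys.t()                  # inner-product similarity, (M, N)
        idx = sims.topk(min(topk, self.keys.size(0)), dim=-1).indices
        return self.keys[idx], self.values[idx]

    def clear(self):
        self.keys = self.keys[:0]
        self.values = self.values[:0]
```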
### XL-style Attention + Memory Clearing
- Recurrence with cached memory states
- Implements custom memory clearing to avoid stale-token influence
- Helps stability in long training runs
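A hedged sketch of XL-style recurrence: the previous segment's keys/values are cached, detached so gradients do not flow across segments, and prepended to the current segment; clearing is reduced here to a simple reset at segment or document boundaries. The class and method names are illustrative only.

```python
import torch

class XLCache:
    """Keeps the previous segment's K/V, detached, for XL-style recurrence."""

    def __init__(self):
        self.mem_k = None
        self.mem_v = None

    def extend(self, k, v):
        # Prepend cached memory (if any) along the time dimension.
        if self.mem_k is not None:
            k = torch.cat([self.mem_k, k], dim=-2)
            v = torch.cat([self.mem_v, v], dim=-2)
        return k, v

    def update(self, k, v):
        # Detach so gradients never flow into past segments.
        self.mem_k, self.mem_v = k.detach(), v.detach()

    def clear(self):
        # Called at boundaries to avoid stale-token influence.
        self.mem_k = self.mem_v = None
```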
## Positional Encoding
- Rotary Positional Encoding (RoPE) replaces standard sin/cos encoding
- RoPE improves generalization over longer contexts
- Implemented in `model_core/attention.py`
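A compact sketch of applying rotary embeddings to query/key tensors is shown below; the in-repo version may differ in layout (interleaved pairs vs. split halves) and caching, so treat this as an assumption-laden illustration.

```python
import torch

def rope(x, base=10000.0):
    """Apply rotary positional encoding to x of shape (B, H, T, d), with d even."""
    B, H, T, d = x.shape
    half = d // 2
    # Per-dimension rotation frequencies and per-position angles.
    freqs = base ** (-torch.arange(half, device=x.device, dtype=torch.float32) / half)
    angles = torch.arange(T, device=x.device, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()            # (T, half) each
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by the position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```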
## Dataloader & Dataset Handling
- Sharded training data stored as `.npy` files
- Matching stride and memory-alignment logic
- Optimized for DDP compatibility and large-scale throughput
- Code in `model_core/dataloader.py` and `data/fineweb.py`
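A hedged sketch of reading token shards and serving contiguous `(B, T)` blocks per DDP rank; the repo's loader additionally handles memory alignment and boundary cases, and the class below assumes shards large enough to always fill a batch.

```python
import numpy as np
import torch

class ShardedLoader:
    """Iterates .npy token shards, serving contiguous (B, T) blocks per rank."""

    def __init__(self, shard_paths, B, T, rank=0, world_size=1):
        self.shard_paths, self.B, self.T = shard_paths, B, T
        self.rank, self.world_size = rank, world_size
        self.shard_idx = 0
        self._load_shard()

    def _load_shard(self):
        self.tokens = torch.from_numpy(
            np.load(self.shard_paths[self.shard_idx]).astype(np.int64))
        # Each rank starts at its own offset so ranks never read the same block.
        self.pos = self.B * self.T * self.rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = buf[:-1].view(B, T)          # inputs
        y = buf[1:].view(B, T)           # next-token targets
        self.pos += B * T * self.world_size
        if self.pos + B * T * self.world_size + 1 > len(self.tokens):
            self.shard_idx = (self.shard_idx + 1) % len(self.shard_paths)
            self._load_shard()
        return x, y
```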
## Requirements
Install dependencies:
```bash
pip install -r requirements.txt
```
Ensure PyTorch and CUDA versions match your GPU setup.