# Memorizing Transformer with Grouped Query Attention

An extended GPT-2-style large language model (LLM) that implements core components from the research paper *Memorizing Transformers* (Wu et al., 2022). This project incorporates Grouped Query Attention (GQA), KNN-based memory retrieval, XL-style attention, and Rotary Positional Encoding (RoPE). The training pipeline supports distributed training, data parallelism, and sharded dataset streaming.

---

## Key Features

- Grouped Query Attention: Efficient query representation by grouping multiple attention heads for shared K/V access
- KNN-based Memory: Long-term memory retrieval from past activations using a learned KNN mechanism
- XL-style Attention: Recurrence-based memory layers adapted for KNN and grouped attention logic
- Rotary Positional Encoding: More efficient and generalizable positional representation than vanilla sin-cos encoding
- Sharded Dataset Loader: Handles large datasets with sharding and supports data parallelism via PyTorch DDP
- Custom Memory Clearing Logic: Memory reset and lifespan mechanisms tuned for stability and performance during training
- Mixed Precision & DDP Training: Efficient large-scale training using `torch.autocast` and `torchrun`

---

## Project Structure

```bash
MEM_TRANSFORMER/
├── configs/
│   └── config.json          # Model + training hyperparameters
├── data/
│   ├── edu_fineweb/         # Token-sharded training data
│   │   ├── train_000001.npy
│   │   ├── train_000002.npy
│   │   └── test_000001.npy
│   ├── hellaswag/
│   │   └── hellaswag_val.jsonl
│   └── fineweb.py           # Sharding logic with memory-aligned sequence control
├── model_core/
│   ├── __init__.py
│   ├── attention.py         # Grouped Query Attention, KNN & XL attention logic, Rotary Positional Encoding
│   ├── model.py             # Transformer model with memory and RoPE support
│   ├── dataloader.py        # Memory-aware DataLoader
│   └── training.py          # train_memgpt function
├── scripts/
│   ├── train.py             # Training script (DDP-compatible)
│   ├── evaluate.py          # Evaluation on benchmarks
│   └── generate.py          # Text generation from trained model
├── evaluation/
│   ├── __init__.py
│   ├── hellaswag.py         # HellaSwag data loader
│   └── val_hellaswag.py     # Evaluation logic with loss-based scoring
├── logs/
│   ├── log.txt              # Training logs
│   └── model_*.pt           # Checkpoints
├── .gitignore
├── README.md
└── requirements.txt
```

## Configuration

Edit the config file at `configs/config.json` to adjust model and training hyperparameters:

```json
{
  "model": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "n_kv_head": 4,
    "max_knn_memories": 81920
  },
  "training": {
    "max_steps": 19073,
    "log_dir": "log",
    "total_batch_size": 2048,
    "B": 64,
    "T": 1024,
    "max_lr": 0.0006,
    "min_lr": 0.00006,
    "warmup_steps": 715,
    "weight_decay": 0.1,
    "learning_rate": 0.0006
  }
}
```

## Training

### Single-GPU Training

```bash
python scripts/train.py
```

### Distributed Training (Multi-GPU with DDP)

```bash
torchrun --nproc_per_node=NUM_GPUS scripts/train.py
```

Replace `NUM_GPUS` with the number of GPUs available.

## Evaluation

Evaluate on the HellaSwag benchmark:

```bash
python scripts/evaluate.py
```

Make sure the file `data/hellaswag/hellaswag_val.jsonl` is present. The evaluation uses completion scoring based on masked loss comparisons across candidate endings.
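The snippet below is a minimal sketch of that scoring idea, assuming a model whose forward pass returns logits of shape `(1, T, vocab_size)`. The helper name `score_completions` is hypothetical; the actual logic lives in `evaluation/val_hellaswag.py` and may differ in details.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def score_completions(model, ctx_tokens, ending_tokens_list, device="cuda"):
    """Pick the candidate ending with the lowest average loss on its own tokens.

    ctx_tokens: list[int], the shared context
    ending_tokens_list: list of list[int], one per candidate ending
    (Illustrative helper only; assumes model(tokens) -> (1, T, vocab_size) logits.)
    """
    losses = []
    for ending in ending_tokens_list:
        tokens = torch.tensor(ctx_tokens + ending, device=device).unsqueeze(0)  # (1, T)
        logits = model(tokens)
        # Shift so logits at position t predict token t+1.
        shift_logits = logits[:, :-1, :]
        shift_targets = tokens[:, 1:]
        per_token = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_targets.reshape(-1),
            reduction="none",
        ).view(1, -1)
        # Mask out the context: only the ending tokens contribute to the score.
        mask = torch.zeros_like(per_token)
        mask[:, len(ctx_tokens) - 1:] = 1.0
        losses.append((per_token * mask).sum() / mask.sum())
    return int(torch.stack(losses).argmin())
```

Averaging the loss over only the ending tokens keeps a long shared context from dominating the comparison, which is the usual approach for HellaSwag-style completion scoring.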
## Attention Mechanism Notes

### Grouped Query Attention (GQA)

- `n_head` query heads, `n_kv_head` shared key/value heads
- Query heads are grouped and averaged before memory lookup (a minimal illustrative sketch appears at the end of this README)
- More efficient than per-head K/V for large models

### KNN Memory Integration

- A maximum memory buffer of 81920 tokens (`max_knn_memories`)
- Query vectors are projected and grouped for efficient KNN search
- Careful shape transformations ensure fast grouped matching

### XL-style Attention + Memory Clearing

- Recurrence with cached memory states
- Implements custom memory clearing to avoid stale token influence
- Helps stability in long training runs

## Positional Encoding

- Rotary Positional Encoding (RoPE) replaces standard sin/cos encoding
- RoPE improves generalization over longer contexts
- Implemented in `model_core/rotary.py`

## Dataloader & Dataset Handling

- Sharded training data using `.npy` files
- Matching stride and memory alignment logic
- Optimized for DDP compatibility and large-scale throughput
- Code in `model_core/dataloader.py` and `data/fineweb.py`

## Requirements

Install dependencies:

```bash
pip install -r requirements.txt
```

Ensure PyTorch and CUDA versions match your GPU setup.
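For reference, the sketch below illustrates the grouped-query step described in the attention notes above: `n_head` query heads are split into `n_kv_head` groups and averaged within each group before attending to the shared key/value heads. It is an illustrative toy under those assumptions, not the code in `model_core/attention.py` (which additionally handles KNN memory, XL recurrence, and RoPE).

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_kv_head):
    """Toy grouped-query attention (illustrative only).

    q:    (B, n_head, T, head_dim)     query heads
    k, v: (B, n_kv_head, T, head_dim)  shared key/value heads
    Query heads are split into n_kv_head groups and averaged within each
    group, mirroring the "grouped and averaged before memory lookup" step.
    """
    B, n_head, T, hd = q.shape
    group = n_head // n_kv_head
    # (B, n_kv_head, group, T, head_dim) -> mean over the group dimension
    q_grouped = q.view(B, n_kv_head, group, T, hd).mean(dim=2)
    att = (q_grouped @ k.transpose(-2, -1)) / hd**0.5        # (B, n_kv_head, T, T)
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = att.masked_fill(~causal, float("-inf"))
    out = F.softmax(att, dim=-1) @ v                         # (B, n_kv_head, T, head_dim)
    # Broadcast each group's output back to its query heads.
    return out.repeat_interleave(group, dim=1)               # (B, n_head, T, head_dim)

# Example with the config's n_head=12 and n_kv_head=4:
q = torch.randn(2, 12, 16, 64)
k = torch.randn(2, 4, 16, 64)
v = torch.randn(2, 4, 16, 64)
print(grouped_query_attention(q, k, v, n_kv_head=4).shape)  # torch.Size([2, 12, 16, 64])
```

With `n_head=12` and `n_kv_head=4`, each group of three query heads shares one K/V head, cutting the K/V projection and cache size by a factor of three relative to per-head K/V.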