abhinavv3 commited on
Commit
f6d6286
·
0 Parent(s):

Repo before implementing concepts of the paper memorizing transformer

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/hellaswag/hellaswag_val.jsonl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ data/edu_fineweb10B
2
+ log/model*
3
+ !log/model_final.pt
Readme.md ADDED
Binary file (7.88 kB). View file
 
configs/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "block_size": 1024,
4
+ "vocab_size": 50304,
5
+ "n_layer": 12,
6
+ "n_head": 12,
7
+ "n_embd": 768
8
+ },
9
+ "training": {
10
+ "max_steps": 19073,
11
+ "log_dir": "log",
12
+ "total_batch_size": 524288,
13
+ "B": 64,
14
+ "T": 1024,
15
+ "max_lr": 0.0006,
16
+ "min_lr": 0.00006,
17
+ "warmup_steps": 715,
18
+ "weight_decay": 0.1,
19
+ "learning_rate": 0.0006
20
+ }
21
+ }
data/__init__.py ADDED
File without changes
data/fineweb.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FineWeb-Edu dataset (for srs pretraining)
3
+ https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu
4
+ Downloads and tokenizes the data and saves data shards to disk.
5
+ Will save shards to the local directory "edu_fineweb10B".
6
+ """
7
+ import os
8
+ import multiprocessing as mp
9
+ import numpy as np
10
+ import tiktoken
11
+ from datasets import load_dataset
12
+ from tqdm import tqdm
13
+
14
+
15
+ local_dir = "edu_fineweb10B"
16
+ remote_name = "sample-10BT"
17
+ shard_size = int(1e8) # 100M tokens per shard, total of 100 shards
18
+
19
+ DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
20
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
21
+ print("Shards will be saved to:",DATA_CACHE_DIR)
22
+
23
+ #dataset download
24
+ fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
25
+
26
+ #tokenizer
27
+ enc = tiktoken.get_encoding("gpt2")
28
+ eot = enc._special_tokens['<|endoftext|>'] # end of text token
29
+
30
+ def tokenize(doc):
31
+ # tokenizes a single document and returns a numpy array of uint16 tokens
32
+ tokens = [eot]
33
+ tokens.extend(enc.encode_ordinary(doc["text"]))
34
+ tokens_np = np.array(tokens)
35
+ assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
36
+ tokens_np_uint16 = tokens_np.astype(np.uint16)
37
+ return tokens_np_uint16
38
+
39
+ def write_datafile(filename, tokens_np):
40
+ np.save(filename, tokens_np)
41
+
42
+ nprocs = max(1, os.cpu_count()//2)
43
+ with mp.Pool(nprocs) as pool:
44
+ shard_index = 0
45
+ # preallocate buffer to hold current shard
46
+ all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
47
+ token_count = 0
48
+ progress_bar = None
49
+ for tokens in pool.imap(tokenize, fw, chunksize=16):
50
+
51
+ # is there enough space in the current shard for the new tokens?
52
+ if token_count + len(tokens) < shard_size:
53
+ # simply append tokens to current shard
54
+ all_tokens_np[token_count:token_count+len(tokens)] = tokens
55
+ token_count += len(tokens)
56
+ # update progress bar
57
+ if progress_bar is None:
58
+ progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
59
+ progress_bar.update(len(tokens))
60
+ else:
61
+ # write the current shard and start a new one
62
+ split = "val" if shard_index == 0 else "train"
63
+ filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
64
+ # split the document into whatever fits in this shard; the remainder goes to next one
65
+ remainder = shard_size - token_count
66
+ progress_bar.update(remainder)
67
+ all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
68
+ write_datafile(filename, all_tokens_np)
69
+ shard_index += 1
70
+ progress_bar = None
71
+ # populate the next shard with the leftovers of the current doc
72
+ all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
73
+ token_count = len(tokens)-remainder
74
+
75
+ # write any remaining tokens as the last shard
76
+ if token_count != 0:
77
+ split = "val" if shard_index == 0 else "train"
78
+ filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
79
+ write_datafile(filename, all_tokens_np[:token_count])
evaluation/__init__.py ADDED
File without changes
evaluation/hellaswag.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downloads and evaluates HellaSwag in Python.
3
+ https://github.com/rowanz/hellaswag
4
+
5
+ """
6
+ import os
7
+ import json
8
+ import requests
9
+ import tiktoken
10
+ from tqdm import tqdm
11
+ import torch
12
+ from torch.nn import functional as F
13
+
14
+ DATA_DOWNLOADED_PATH = '"data/hellaswag"'
15
+
16
+ def download_file(url:str, fname:str, chunk_size=1024):
17
+ resp = requests.get(url, stream=True)
18
+ total = int(resp.headers.get("content-length", 0 ))
19
+ with open(fname, "wb") as file, tqdm(
20
+ desc = fname,
21
+ total=total,
22
+ unit="iB",
23
+ unit_scale=True,
24
+ unit_divisor=1024
25
+ )as bar:
26
+ for data in resp.iter_content(chunk_size=chunk_size):
27
+ size = file.write(data)
28
+ bar.update(size)
29
+
30
+ hellaswags = {
31
+ "train": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_train.jsonl",
32
+ "val": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl",
33
+ "test": "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_test.jsonl",
34
+ }
35
+
36
+ enc = tiktoken.get_encoding("gpt2")
37
+
38
+ def download(split):
39
+ """Downloads HellaSwag DATA_DOWNLOADED_PATH"""
40
+ os.makedirs(DATA_DOWNLOADED_PATH, exist_ok=True)
41
+ data_url = hellaswags[split]
42
+ data_filename = os.path.join(DATA_DOWNLOADED_PATH, f"hellaswag_{split}.jsonl")
43
+ if not os.path.exists(data_filename):
44
+ print(f"Downloading {data_url} to {data_filename}...")
45
+ download_file(data_url, data_filename)
46
+
47
+ def render_example(example):
48
+ """
49
+ Given the example as a dictionary, render it as three torch tensors:
50
+ - tokens (the tokens of context + completion, of size 4xN, as there are always 4 candidates)
51
+ - mask (is 1 in the region of the candidate completion, where we evaluate likelihoods)
52
+ - label (the index of the correct completion, which we hope has the highest likelihood)
53
+ """
54
+ ctx = example["ctx"]
55
+ label = example["label"]
56
+ endings = example["endings"]
57
+
58
+ # data needed to reproduce this eval on the C size
59
+ data = {
60
+ "label": label,
61
+ "ctx_tokens": None,
62
+ "ending_tokens": [],
63
+ }
64
+
65
+ # gather up all the tokens
66
+ ctx_tokens = enc.encode(ctx)
67
+ data["ctx_tokens"] = ctx_tokens
68
+ tok_rows = []
69
+ mask_rows = []
70
+ for end in endings:
71
+ end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
72
+ tok_rows.append(ctx_tokens + end_tokens)
73
+ mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
74
+ data["ending_tokens"].append(end_tokens)
75
+
76
+ # have to be careful during the collation because the number of tokens in each row can differ
77
+ max_len = max(len(row) for row in tok_rows)
78
+ tokens = torch.zeros((4, max_len), dtype=torch.long)
79
+ mask = torch.zeros((4, max_len), dtype=torch.long)
80
+ for i, (tok_row, mask_row) in enumerate(zip(tok_rows, mask_rows)):
81
+ tokens[i, :len(tok_row)] = torch.tensor(tok_row)
82
+ mask[i, :len(mask_row)] = torch.tensor(mask_row)
83
+
84
+ return data, tokens, mask, label
85
+
86
+ def iterate_examples(split):
87
+ # there are 10,042 examples in total in val
88
+ download(split)
89
+ with open(os.path.join(DATA_DOWNLOADED_PATH, f"hellaswag_{split}.jsonl"), "r") as f:
90
+ for line in f:
91
+ example = json.loads(line)
92
+ yield example
93
+
94
+
95
+ def get_most_likely_row(tokens, mask, logits):
96
+ shift_logits = (logits[..., :-1, :]).contiguous() #this will be x for loss calculation
97
+ shift_tokens = (tokens[..., 1:]).contiguous() #this will be y for loss calculation
98
+ shift_mask = (mask[..., 1:]).contiguous() #shifting same as tokens shifted
99
+ flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
100
+ flat_shift_tokens = shift_tokens.view(-1)
101
+ shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
102
+ shift_losses = shift_losses.view(tokens.size(0), -1)
103
+ masked_shift_losses = shift_losses * shift_mask
104
+ sum_loss = masked_shift_losses.sum(dim=1)
105
+ avg_loss = sum_loss / shift_mask.sum(dim=1)
106
+ pred_norm = avg_loss.argmin().item() #taking the index of minimum loss
107
+ return pred_norm
108
+
109
+
110
+
111
+
112
+
113
+
evaluation/val_hellaswag.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ..hellaswag import render_example, iterate_examples, get_most_likely_row
3
+ import torch.distributed as dist
4
+ from torch.distributed import init_process_group, destroy_process_group
5
+ from torch.nn.parallel import DistributedDataParallel as DDP
6
+ import os
7
+ from ..ModelGPT2 import GPT,log_file
8
+
9
+ ddp = int(os.environ.get('RANK', -1)) != -1 #will be True if ddp run
10
+ if ddp:
11
+ assert torch.cuda.is_available()
12
+ init_process_group(backend='nccl')
13
+ ddp_rank = int(os.environ['RANK'])
14
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
15
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
16
+ device = f"cuda:{ddp_local_rank}"
17
+ torch.cuda.set_device(device)
18
+ master_process = ddp_rank == 0 #this is the process doing checkpoint,logging,etc
19
+ else:
20
+ ddp_rank = 0
21
+ ddp_local_rank = 0
22
+ ddp_world_size = 1
23
+ master_process = True
24
+ #attempt to autodetect the device
25
+ device = 'cpu'
26
+ if torch.cuda.is_available():
27
+ device = 'cuda'
28
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
+ device = "mps" #for mac users use apple silicon cpu which allready have gpu.mps is backend for apple silicon
30
+ print(f"Using device: {device}")
31
+ # device = "cpu" #OVERRIDE
32
+
33
+ device_type = "cuda" if device.startswith("cuda") else "cpu"
34
+
35
+ torch.manual_seed(1337)
36
+ if torch.cuda.is_available():
37
+ torch.cuda.manual_seed(1337)
38
+
39
+
40
+ #Creating model by loading the model weights
41
+ checkpoint_path = '../log/model_final.pt'
42
+ if master_process:
43
+ print(f"Loading checkpoint from {checkpoint_path}")
44
+
45
+ checkpoint = torch.load(checkpoint_path, map_location=device)
46
+
47
+ # Extract config and create model
48
+ model_config = checkpoint['config']
49
+ model_config.vocab_size = 50304 #for computational effciency(power of 2)
50
+ model = GPT(model_config)
51
+ # Load model state dict
52
+ model.load_state_dict(checkpoint['model'])
53
+ model = DDP(model, device_ids=[ddp_local_rank])
54
+ model.to(device)
55
+
56
+
57
+ def evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process):
58
+
59
+ num_correct_norm = 0
60
+ num_total = 0
61
+
62
+ for i, example in enumerate(iterate_examples("val")):
63
+ # only process example where i % ddp_world_size ==ddp_rank#this is for proper managemnt of which part is deal by which gpu
64
+ if ddp:
65
+ if i % ddp_world_size != ddp_rank:
66
+ continue
67
+ #rendering example into tokens and labels
68
+ _, tokens, mask, label = render_example(example)
69
+ tokens = tokens.to(device)
70
+ mask = mask.to(device)
71
+ #get the logits
72
+ with torch.no_grad():
73
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
74
+ logits, loss = model(tokens)
75
+ pred_norm = get_most_likely_row(tokens, mask, logits)
76
+ num_total += 1
77
+ num_correct_norm += int(pred_norm == label)
78
+ #reduce the stats accross all process
79
+ if ddp:
80
+ num_total = torch.tensor(num_total, dtype=torch.long, device=device)
81
+ num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
82
+ dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
83
+ dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
84
+ num_total = num_total.item()
85
+ num_correct_norm = num_correct_norm.item()
86
+ acc_norm = num_correct_norm / num_total #accuracy of hellaswag
87
+ if master_process:
88
+ print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
89
+ with open(log_file, "a") as f:
90
+ f.write(f"Final Hellaswag accuracy: {acc_norm:.4f}\n")
91
+
92
+ evaluate_hellaswag(model, device, device_type, ddp, ddp_rank, ddp_world_size, log_file, master_process)
93
+ if ddp:
94
+ destroy_process_group()
log/log.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 0 val 10.9528
model_core/__init__.py ADDED
File without changes
model_core/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (168 Bytes). View file
 
model_core/__pycache__/attention.cpython-311.pyc ADDED
Binary file (2.6 kB). View file
 
model_core/__pycache__/dataloader.cpython-311.pyc ADDED
Binary file (3.87 kB). View file
 
model_core/__pycache__/model.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
model_core/__pycache__/training.cpython-311.pyc ADDED
Binary file (10.7 kB). View file
 
model_core/attention.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torch.nn import functional as F
3
+
4
+
5
+ class CasualSelfAttention(nn.Module):
6
+
7
+ def __init__(self, config):
8
+ super().__init__()
9
+ assert config.n_embd % config.n_head == 0
10
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
11
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
12
+ self.c_proj.NANOGPT_SCALE_INIT = 1
13
+ self.n_head = config.n_head
14
+ self.n_embd = config.n_embd
15
+
16
+ def forward(self, x):
17
+ B, T, C = x.size()
18
+ qkv = self.c_attn(x)
19
+ q, k, v = qkv.split(self.n_embd, dim=2)
20
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
21
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
22
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
23
+
24
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True) #flash attention
25
+
26
+ y = y.transpose(1,2).contiguous().view(B, T, C) # (B, T, C)
27
+ y = self.c_proj(y)
28
+ return y
model_core/dataloader.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+
5
+ #Data loader
6
+ class DataLoader_1:
7
+ def __init__(self, B, T, process_rank, num_processes, split, master_process):
8
+ self.B = B
9
+ self.T = T
10
+ self.process_rank = process_rank
11
+ self.num_processes = num_processes
12
+ assert split in {'train', 'val'}
13
+
14
+
15
+ data_root = "data/edu_fineweb10B"
16
+ shards = os.listdir(data_root)
17
+ shards = [s for s in shards if split in s]
18
+ shards = sorted(shards)
19
+ shards = [os.path.join(data_root, s) for s in shards]
20
+ self.shards = shards
21
+ assert len(shards)> 0, f"no shards found for split {split}"
22
+ if master_process:
23
+ print(f"found {len(shards)} shards for split {split}")
24
+ self.reset()
25
+
26
+ def load_tokens(self, filename):
27
+ npt = np.load(filename)
28
+ npt = npt.astype(np.int32)
29
+ ptt = torch.tensor(npt, dtype=torch.long)
30
+ return ptt
31
+
32
+
33
+ def reset(self):
34
+ #state, init at shard 0
35
+ self.current_shard = 0
36
+ self.tokens = self.load_tokens(self.shards[self.current_shard])
37
+ self.current_position = self.B * self.T * self.process_rank
38
+
39
+ def next_batch(self):
40
+ B, T = self.B, self.T
41
+ buf = self.tokens[self.current_position:self.current_position + B*T+1]
42
+ x = (buf[:-1]).view(B,T)
43
+ y = (buf[1:]).view(B,T)
44
+
45
+ self.current_position += B * T * self.num_processes
46
+
47
+ if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
48
+ self.current_shard = (self.current_shard + 1) % len(self.shards)
49
+ self.tokens = self.load_tokens(self.shards[self.current_shard])
50
+ self.current_position = B * T * self.process_rank
51
+ return x, y
52
+
53
+
model_core/model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ import inspect
6
+ from .attention import CasualSelfAttention
7
+
8
+ class MLP(nn.Module):
9
+
10
+ def __init__(self, config):
11
+ super().__init__()
12
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
13
+ self.gelu = nn.GELU(approximate='tanh')
14
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
15
+ self.c_proj.NANOGPT_SCALE_INIT = 1
16
+
17
+ def forward(self, x):
18
+ x = self.c_fc(x)
19
+ x = self.gelu(x)
20
+ x = self.c_proj(x)
21
+ return x
22
+
23
+
24
+ class Block(nn.Module):
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ self.ln_1 = nn.LayerNorm(config.n_embd)
28
+ self.attn = CasualSelfAttention(config)
29
+ self.ln_2 = nn.LayerNorm(config.n_embd)
30
+ self.mlp = MLP(config)
31
+
32
+ def forward(self, x):
33
+ x = x + self.attn(self.ln_1(x))
34
+ x = x + self.mlp(self.ln_2(x))
35
+ return x
36
+
37
+
38
+ @dataclass
39
+ class GPTConfig:
40
+ block_size: int = 1024 #max sequence length
41
+ vocab_size: int = 50257 #number of tokens: 50000 BPE merges + 256 byte tokens +1 special token which is endoftext
42
+ n_layer: int = 12 #number of layers
43
+ n_head: int = 12 #number of heads
44
+ n_embd: int = 768 #embedding dimensions
45
+
46
+
47
+ class GPT(nn.Module):
48
+ def __init__(self, config):
49
+ super().__init__()
50
+ self.config = config
51
+
52
+ self.transformer = nn.ModuleDict(dict(
53
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
54
+ wpe = nn.Embedding(config.block_size, config.n_embd),
55
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
56
+ ln_f = nn.LayerNorm(config.n_embd),
57
+ ))
58
+
59
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
60
+
61
+ #Weight sharing scheme
62
+ self.transformer.wte.weight = self.lm_head.weight
63
+
64
+ # init params
65
+ self.apply(self._init_weights)
66
+
67
+ def _init_weights(self, module):
68
+ if isinstance(module, nn.Linear):
69
+ std = 0.02
70
+ if hasattr(module, 'NANOGPT_SCALE_INIT'):
71
+ std *= (2 * self.config.n_layer) ** -0.5
72
+ torch.nn.init.normal_(module.weight, mean = 0.0, std=std)
73
+ if module.bias is not None:
74
+ torch.nn.init.zeros_(module.bias)
75
+ elif isinstance(module, nn.Embedding):
76
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
77
+
78
+ def forward(self, idx, targets=None):
79
+ B, T = idx.size()
80
+ assert T <=self.config.block_size, f"Cannot forward sequence of length {T} ,block size is only {self.config.block_size}"
81
+
82
+ pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
83
+ pos_emb = self.transformer.wpe(pos)
84
+ tok_emb = self.transformer.wte(idx)
85
+ x = tok_emb + pos_emb
86
+
87
+ for block in self.transformer.h:
88
+ x = block(x)
89
+
90
+ x = self.transformer.ln_f(x)
91
+ logits = self.lm_head(x) #(B, T, vocab_size)
92
+ loss = None
93
+ if targets is not None:
94
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
95
+
96
+ return logits, loss
97
+
98
+ def configure_optimizers(self, weight_decay, learning_rate, device_type, master_process):
99
+ param_dict = {pn:p for pn, p in self.named_parameters()}
100
+ param_dict = {pn:p for pn, p in param_dict.items() if p.requires_grad}
101
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
102
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
103
+ optim_groups = [{'params':decay_params, ' weight_decay': weight_decay},
104
+ {'params':nodecay_params, 'weight_decay': 0.0}
105
+ ]
106
+ num_decay_params = sum(p.numel() for p in decay_params)
107
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
108
+ if master_process:
109
+ print(f"num decayed parameters tensors: {len(decay_params)}, with{num_decay_params}:parameters")
110
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
111
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
112
+ use_fused = fused_available and device_type == "cuda"
113
+ if master_process:
114
+ print(f"using fused AdamW: {use_fused}")
115
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9,0.95), eps=1e-8, fused=use_fused)
116
+ return optimizer
117
+
model_core/training.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.distributed import init_process_group, destroy_process_group
2
+ from torch.nn.parallel import DistributedDataParallel as DDP
3
+ import torch.distributed as dist
4
+ import os
5
+ import torch
6
+ import time
7
+ import json
8
+ import math
9
+ from .model import GPT,GPTConfig
10
+
11
+
12
+ def train_memgpt(config_path,dataloader_class=None):
13
+
14
+ with open(config_path,'r') as f:
15
+ cfg = json.load(f)
16
+
17
+ model_cfg_params = cfg['model']
18
+ train_cfg_params = cfg['training']
19
+
20
+ ddp = int(os.environ.get('RANK', -1)) != -1
21
+ if ddp:
22
+ assert torch.cuda.is_available()
23
+ init_process_group(backend='nccl')
24
+ ddp_rank = int(os.environ['RANK'])
25
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
26
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
27
+ device = f"cuda:{ddp_local_rank}"
28
+ torch.cuda.set_device(device)
29
+ master_process = ddp_rank == 0
30
+ else:
31
+ ddp_rank = 0
32
+ ddp_local_rank = 0
33
+ ddp_world_size = 1
34
+ master_process = True
35
+ device = 'cpu'
36
+ if torch.cuda.is_available():
37
+ device = 'cuda'
38
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
39
+ device = "mps"
40
+ if master_process:
41
+ print(f"Using device: {device}")
42
+
43
+ device_type = "cuda" if device.startswith("cuda") else "cpu"
44
+
45
+ torch.manual_seed(1337)
46
+ if torch.cuda.is_available():
47
+ torch.cuda.manual_seed(1337)
48
+
49
+
50
+
51
+ total_batch_size = train_cfg_params['total_batch_size']
52
+ B = train_cfg_params['B']
53
+ T = train_cfg_params['T']
54
+ max_steps = train_cfg_params['max_steps']
55
+ log_dir = train_cfg_params['log_dir']
56
+ max_lr = train_cfg_params['max_lr']
57
+ min_lr = train_cfg_params['min_lr']
58
+ warmup_steps = train_cfg_params['warmup_steps']
59
+ weight_decay = train_cfg_params['weight_decay']
60
+ base_learning_rate = train_cfg_params['learning_rate']
61
+
62
+ assert total_batch_size % (B * T * ddp_world_size) == 0
63
+ grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
64
+ if master_process:
65
+ print(f"Total desired batch size: {total_batch_size}")
66
+ print(f"Calculated gradient accumulation steps: {grad_accum_steps}")
67
+
68
+ train_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train",master_process=master_process)
69
+ val_loader = dataloader_class(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val",master_process=master_process)
70
+
71
+ torch.set_float32_matmul_precision('high')
72
+
73
+ # Create Model
74
+ model = GPT(GPTConfig(**model_cfg_params))
75
+ model.to(device)
76
+ use_compile = True
77
+ if use_compile:
78
+ model = torch.compile(model)
79
+ if ddp:
80
+ model = DDP(model, device_ids=[ddp_local_rank])
81
+ raw_model = model.module if ddp else model
82
+
83
+ def get_lr(it):
84
+ if it < warmup_steps:
85
+ return max_lr * (it + 1) / warmup_steps
86
+ if it > max_steps:
87
+ return min_lr
88
+ decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
89
+ assert 0 <= decay_ratio <= 1
90
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
91
+ return min_lr + coeff * (max_lr - min_lr)
92
+
93
+ optimizer = raw_model.configure_optimizers(weight_decay=weight_decay, learning_rate=base_learning_rate, device_type=device_type, master_process=master_process)
94
+
95
+ os.makedirs(log_dir, exist_ok=True)
96
+ log_file = os.path.join(log_dir, "log.txt")
97
+ with open(log_file, "w") as f:
98
+ pass
99
+
100
+ for step in range(max_steps):
101
+ t0 = time.time()
102
+ last_step = (step == max_steps - 1)
103
+
104
+ if step % 350 == 0 or last_step:
105
+ model.eval()
106
+ val_loader.reset()
107
+ with torch.no_grad():
108
+ val_loss_accum = 0.0
109
+ val_loss_steps = 20
110
+ for _ in range(val_loss_steps):
111
+ x, y = val_loader.next_batch()
112
+ x, y = x.to(device), y.to(device)
113
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
114
+ logits, loss = model(x, y)
115
+ loss = loss / val_loss_steps
116
+ val_loss_accum += loss.detach()
117
+ if ddp:
118
+ dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
119
+ if master_process:
120
+ print(f"Validation loss: {val_loss_accum.item():.4f}")
121
+ with open(log_file, "a") as f:
122
+ f.write(f"{step} val {val_loss_accum.item():.4f}\n")
123
+
124
+ checkpoint_name = f"model_final.pt" if last_step else f"model_{step:05d}.pt"
125
+ checkpoint_path = os.path.join(log_dir, checkpoint_name)
126
+
127
+ checkpoint = {
128
+ 'model': raw_model.state_dict(),
129
+ 'optimizer': optimizer.state_dict(),
130
+ 'step': step,
131
+ 'val_loss': val_loss_accum.item(),
132
+ 'config': raw_model.config
133
+ }
134
+ torch.save(checkpoint, checkpoint_path)
135
+
136
+
137
+ model.train()
138
+ optimizer.zero_grad()
139
+ loss_accum = 0.0
140
+ for micro_step in range(grad_accum_steps):
141
+ x, y = train_loader.next_batch()
142
+ x, y = x.to(device), y.to(device)
143
+ if ddp:
144
+ model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
145
+ with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
146
+ logits, loss = model(x, y)
147
+ loss = loss / grad_accum_steps
148
+ loss_accum += loss.detach()
149
+ loss.backward()
150
+
151
+ if ddp:
152
+ dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
153
+
154
+ norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
155
+ lr = get_lr(step)
156
+ for param_group in optimizer.param_groups:
157
+ param_group['lr'] = lr
158
+ optimizer.step()
159
+ if device_type == 'cuda':
160
+ torch.cuda.synchronize()
161
+ t1 = time.time()
162
+ dt = (t1 - t0) * 1000
163
+ tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
164
+ tokens_per_sec = tokens_processed / dt
165
+ if master_process:
166
+ print(f"Step:{step:5d} | Loss: {loss_accum.item():.6f} | lr: {lr:.4e} | Norm:{norm:.4f} | dt: {dt:.2f}ms | Tok/sec: {tokens_per_sec:.2f}")
167
+ with open(log_file, 'a') as f:
168
+ f.write(f"{step} train {loss_accum.item():.6f}\n")
169
+
170
+ if ddp:
171
+ destroy_process_group()
requirement.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+
3
+ safetensors==0.5.3
4
+ tiktoken==0.9.0
5
+ tokenizers==0.21.1
6
+ transformers==4.50.1
7
+ tqdm==4.67.1
8
+ requests==2.32.3
9
+ numpy<1.27,>=1.22
10
+ torch==2.3.1+cu121
rough_work.py ADDED
File without changes
scripts/evaluate.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #To run all evaluation at once
2
+ #Code yet to be added
scripts/generate.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import tiktoken
4
+ from model import GPT
5
+
6
+ def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda'):
7
+ model.eval()
8
+ enc = tiktoken.get_encoding('gpt2')
9
+ tokens = enc.encode(prompt)
10
+ tokens = torch.tensor(tokens, dtype=torch.long)
11
+ tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
12
+ xgen = tokens.to(device)
13
+ sample_rng = torch.Generator(device=device)
14
+ sample_rng.manual_seed(42)
15
+
16
+ while xgen.size(1) < max_length:
17
+ with torch.no_grad():
18
+ logits, loss = model(xgen) # (B, T, vocab_size)
19
+ logits = logits[:, -1, :] # (B, vocab_size)
20
+ probs = F.softmax(logits, dim=-1)
21
+ topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
22
+ ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
23
+ xcol = torch.gather(topk_indices, -1, ix)
24
+ xgen = torch.cat((xgen, xcol), dim=1)
25
+
26
+ generated_texts = []
27
+ for i in range(num_return_sequences):
28
+ tokens = xgen[i, :max_length].tolist()
29
+ decoded = enc.decode(tokens)
30
+ generated_texts.append(decoded)
31
+ print(f"Sample {i + 1}: {decoded}")
32
+
33
+
34
+ return generated_texts
35
+
36
+
37
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
+ print(f"running with {device}")
39
+
40
+
41
+ checkpoint_path = 'log/model_final.pt'
42
+
43
+ print(f"Loading checkpoint from {checkpoint_path}")
44
+ checkpoint = torch.load(checkpoint_path,map_location=device)
45
+ model_config = checkpoint['config']
46
+ model_config.vocab_size = 50304
47
+ model = GPT(model_config)
48
+
49
+
50
+ model.load_state_dict(checkpoint['model'])
51
+ model.to(device)
52
+
53
+
54
+
55
+ prompt = "Hello, I'm a language model,"
56
+
57
+ generated_texts = generate_text(
58
+ model=model,
59
+ prompt=prompt,
60
+ num_return_sequences=4,
61
+ max_length=32,
62
+ device=device
63
+ )
scripts/train.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ from model_core.training import train_memgpt
6
+ from model_core.dataloader import DataLoader_1
7
+
8
+ if __name__ == "__main__":
9
+ config_path = "configs/config.json"
10
+ print("Training starter")
11
+ train_memgpt(config_path=config_path,dataloader_class=DataLoader_1)