HemanM committed on
Commit ecf4bc9 · verified · 1 Parent(s): c487bf4

Create evo_core_gpt.py

Files changed (1)
evo_core_gpt.py (+99, -0)
evo_core_gpt.py ADDED
@@ -0,0 +1,99 @@
+# evo_core_gpt.py — minimal decoder-only LM matching your checkpoint keys
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class SelfAttention(nn.Module):
+    def __init__(self, d_model: int, num_heads: int):
+        super().__init__()
+        assert d_model % num_heads == 0
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.head_dim = d_model // num_heads
+        # names match your state_dict:
+        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
+        self.out_proj = nn.Linear(d_model, d_model, bias=True)
+
+    def forward(self, x, attn_mask=None):
+        B, T, C = x.shape
+        qkv = self.qkv_proj(x)  # (B, T, 3C)
+        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)
+        # reshape to heads
+        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
+        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
+        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
+        # scaled dot-product attention with causal mask
+        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
+        # causal mask: each position attends only to itself and earlier positions
+        causal = torch.ones((T, T), device=x.device, dtype=torch.bool).tril()
+        att = att.masked_fill(~causal, float("-inf"))
+        if attn_mask is not None:
+            att = att + attn_mask  # additive mask, broadcast over (B, H, T, T)
+        att = F.softmax(att, dim=-1)
+        y = att @ v  # (B, H, T, Hd)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
+        return self.out_proj(y)
+
+class FFN(nn.Module):
+    def __init__(self, d_model: int, mult: int = 4):
+        super().__init__()
+        # match your names: ffn.net.0 (Linear) and ffn.net.3 (Linear), with GELU (net.1) and Dropout (net.2) in between
+        self.net = nn.Sequential(
+            nn.Linear(d_model, mult * d_model, bias=True),
+            nn.GELU(),
+            nn.Dropout(p=0.0),
+            nn.Linear(mult * d_model, d_model, bias=True),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class Block(nn.Module):
+    def __init__(self, d_model: int, num_heads: int):
+        super().__init__()
+        # names must match: blocks.N.ln1/ln2, .attn.*, .ffn.*
+        self.ln1 = nn.LayerNorm(d_model)
+        self.attn = SelfAttention(d_model, num_heads)
+        self.ln2 = nn.LayerNorm(d_model)
+        self.ffn = FFN(d_model)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))  # pre-norm residual attention
+        x = x + self.ffn(self.ln2(x))  # pre-norm residual FFN
+        return x
+
+class EvoGPT(nn.Module):
+    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
+        super().__init__()
+        # names must match your checkpoint:
+        self.token_emb = nn.Embedding(vocab_size, d_model)
+        self.pos_emb = nn.Embedding(n_positions, d_model)
+        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
+        # (no final ln_f in your keys; logits come from tying the output projection to token_emb)
+        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)
+
+    def forward(self, input_ids):
+        B, T = input_ids.shape
+        pos = self.pos_idx[:T]  # T must not exceed n_positions
+        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
+        for blk in self.blocks:
+            x = blk(x)
+        # weight tying: logits = x @ token_emb^T
+        logits = x @ self.token_emb.weight.t()
+        return logits
+
+    @torch.no_grad()
+    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
+        self.eval()
+        for _ in range(int(max_new_tokens)):
+            logits = self.forward(input_ids[:, -self.pos_emb.num_embeddings:])[:, -1, :]  # (B, V); crop context so pos_emb stays in range
+            if temperature and temperature > 0.0:
+                probs = F.softmax(logits / max(0.01, temperature), dim=-1)
+                next_id = torch.multinomial(probs, num_samples=1)
+            else:
+                next_id = torch.argmax(logits, dim=-1, keepdim=True)
+            input_ids = torch.cat([input_ids, next_id], dim=1)
+            if eos_token_id is not None and (next_id == eos_token_id).all():
+                break
+        return input_ids
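
A short usage sketch may help with wiring this module up against your checkpoint. The hyperparameters, the checkpoint filename evo_core_gpt.pt, and the raw token IDs below are illustrative assumptions, not values taken from the actual model; they have to be replaced with the sizes your state_dict was trained with.

# usage sketch: all sizes, paths, and token IDs here are placeholders for illustration
import torch
from evo_core_gpt import EvoGPT

# Placeholder hyperparameters; the real values must match the checkpoint tensors
# (e.g. token_emb.weight is (vocab_size, d_model), pos_emb.weight is (n_positions, d_model)).
model = EvoGPT(vocab_size=50257, d_model=512, n_layers=8, n_positions=1024, num_heads=8)

state_dict = torch.load("evo_core_gpt.pt", map_location="cpu")  # hypothetical checkpoint path
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing keys:", missing)        # non-empty means this module defines layers the checkpoint lacks
print("unexpected keys:", unexpected)  # non-empty means the checkpoint has keys (e.g. ln_f.*) this module does not

model.eval()
prompt_ids = torch.tensor([[1, 2, 3]])  # stand-in token IDs; replace with your tokenizer's output
out = model.generate(prompt_ids, max_new_tokens=32, temperature=0.7)
print(out.shape)  # (1, 35): the prompt plus 32 sampled tokens

Once both printed key lists are empty, loading can be switched to strict=True; if they are not empty, the mismatch would typically point at an architectural difference (a final LayerNorm, an untied lm_head, different block naming) rather than at the loading code itself.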