Create evo_core_gpt.py

evo_core_gpt.py · ADDED · +99 -0
# evo_core_gpt.py — minimal decoder-only LM matching your checkpoint keys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # names match your state_dict:
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        qkv = self.qkv_proj(x)          # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)
        # reshape to heads
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, T, Hd)
        # scaled dot-product attention with causal mask
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (B, H, T, T)
        # causal mask
        causal = torch.ones((T, T), device=x.device, dtype=torch.bool).tril()
        att = att.masked_fill(~causal, float("-inf"))
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v                                       # (B, H, T, Hd)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        return self.out_proj(y)

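# Note: with attn_mask=None, the masked softmax above should match
# F.scaled_dot_product_attention(q, k, v, is_causal=True) on PyTorch >= 2.0;
# it is written out manually here so the attention math stays explicit.
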
class FFN(nn.Module):
    def __init__(self, d_model: int, mult: int = 4):
        super().__init__()
        # match your names: ffn.net.0 (Linear), .3 (Linear), with GELU in between
        self.net = nn.Sequential(
            nn.Linear(d_model, mult * d_model, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.0),
            nn.Linear(mult * d_model, d_model, bias=True),
        )

    def forward(self, x):
        return self.net(x)

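# The inert Dropout(p=0.0) above is presumably kept only so the second Linear
# stays at index 3 ("ffn.net.3") and the checkpoint keys line up.
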
class Block(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # names must match: blocks.N.ln1/ln2, .attn.*, .ffn.*
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

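# Pre-LN wiring: LayerNorm runs before attn/ffn inside each residual branch,
# the GPT-2-style arrangement that tends to train more stably than post-LN.
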
class EvoGPT(nn.Module):
    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
        super().__init__()
        # names must match your checkpoint:
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_positions, d_model)
        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
        # (no final ln_f in your keys; logits come from tying the output projection to token_emb)
        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)

    def forward(self, input_ids):
        B, T = input_ids.shape
        pos = self.pos_idx[:T]
        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        # weight tying: logits = x @ token_emb^T
        logits = x @ self.token_emb.weight.t()
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
        self.eval()
        n_ctx = self.pos_idx.numel()
        for _ in range(int(max_new_tokens)):
            # crop the context so pos_emb is never indexed past n_positions
            idx_cond = input_ids[:, -n_ctx:]
            logits = self.forward(idx_cond)[:, -1, :]  # (B, V)
            if temperature and temperature > 0.0:
                probs = F.softmax(logits / max(0.01, temperature), dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            if eos_token_id is not None and (next_id == eos_token_id).all():
                break
        return input_ids
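
A minimal usage sketch (assumptions flagged: the hyperparameters, checkpoint filename, and prompt ids below are placeholders, not values confirmed by this repo — substitute whatever your checkpoint was trained with):

# usage sketch — hypothetical; sizes, path, and prompt are placeholders
import torch
from evo_core_gpt import EvoGPT

model = EvoGPT(vocab_size=50257, d_model=512, n_layers=8,
               n_positions=1024, num_heads=8)            # placeholder sizes
state = torch.load("checkpoint.pt", map_location="cpu")  # placeholder path
model.load_state_dict(state)  # keys: token_emb.*, pos_emb.*, blocks.N.*

prompt = torch.tensor([[1, 2, 3]])  # stand-in for real tokenizer output
out = model.generate(prompt, max_new_tokens=32, temperature=0.7)
print(out.shape)  # (1, 3 + up to 32 generated tokens)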