File size: 4,237 Bytes
ecf4bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# evo_core_gpt.py — minimal decoder-only LM matching your checkpoint keys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Causal multi-head self-attention with a fused q/k/v projection.

    The submodule names ``qkv_proj`` and ``out_proj`` are part of the
    state_dict key layout and must not be renamed.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # A single linear layer emits queries, keys and values side by side
        # along the last axis; they are split apart again in forward().
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape (batch, seq, d_model); return the same shape.

        ``attn_mask``, when given, is added to the raw attention scores
        after the causal mask has been applied and before the softmax.
        """
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.num_heads, self.head_dim)

        # Split the fused projection and move the head axis forward:
        # each of the three tensors ends up as (batch, heads, seq, head_dim).
        queries, keys, values = (
            part.view(per_head).transpose(1, 2)
            for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        # (batch, heads, seq, seq) similarity scores, scaled by sqrt(head_dim).
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Positions strictly above the diagonal are "future" tokens and must
        # receive zero weight after the softmax.
        visible = torch.ones((seq_len, seq_len), device=x.device, dtype=torch.bool).tril()
        scores = scores.masked_fill(visible.logical_not(), float("-inf"))
        if attn_mask is not None:
            scores = scores + attn_mask

        weights = F.softmax(scores, dim=-1)
        # Merge the heads back into a single (batch, seq, d_model) tensor.
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged)

class FFN(nn.Module):
    """Position-wise feed-forward network: expand, GELU, dropout (p=0), project back."""

    def __init__(self, d_model: int, mult: int = 4):
        super().__init__()
        hidden = mult * d_model
        # The positions inside the Sequential are part of the state_dict key
        # layout: the two Linear layers must stay at net.0 and net.3.
        layers = [
            nn.Linear(d_model, hidden, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.0),
            nn.Linear(hidden, d_model, bias=True),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``; the last dimension is preserved."""
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """Pre-LayerNorm transformer layer: self-attention, then feed-forward.

    Each sub-layer reads a normalized copy of its input and its output is
    added back onto the un-normalized input (residual connection).
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # Attribute names double as checkpoint key prefixes
        # (blocks.N.ln1 / .attn.* / .ln2 / .ffn.*), so they must stay as-is.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.ffn(self.ln2(after_attention))

class EvoGPT(nn.Module):
    """Minimal decoder-only language model with learned absolute positions.

    The model sums token and position embeddings, runs the result through
    ``n_layers`` blocks and projects back to the vocabulary by multiplying
    with the transposed token-embedding matrix (weight tying). There is no
    final layer norm. Attribute names (``token_emb``, ``pos_emb``,
    ``blocks``) are part of the checkpoint key layout.
    """

    def __init__(self, vocab_size: int, d_model: int, n_layers: int, n_positions: int, num_heads: int):
        super().__init__()
        # names must match your checkpoint:
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_positions, d_model)
        self.blocks = nn.ModuleList([Block(d_model, num_heads) for _ in range(n_layers)])
        # (no final ln_f in your keys; tokens are projected by tying to token_emb)
        # Non-persistent: the index vector is rebuilt here and never stored
        # in (or expected from) a state_dict.
        self.register_buffer("pos_idx", torch.arange(0, n_positions).long(), persistent=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits of shape (batch, seq, vocab_size).

        Args:
            input_ids: integer token ids of shape (batch, seq).

        Raises:
            ValueError: if ``seq`` exceeds the number of learned positions.
        """
        B, T = input_ids.shape
        max_len = self.pos_emb.num_embeddings
        if T > max_len:
            # Without this check ``pos_idx[:T]`` is silently truncated and the
            # sum below either fails with an opaque broadcasting error or,
            # when there is exactly one position, broadcasts it to every token.
            raise ValueError(
                f"sequence length {T} exceeds the model's maximum of {max_len} positions"
            )
        pos = self.pos_idx[:T]
        x = self.token_emb(input_ids) + self.pos_emb(pos)[None, :, :]
        for blk in self.blocks:
            x = blk(x)
        # weight tying: logits = x @ token_emb^T
        return x @ self.token_emb.weight.t()

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=128, temperature=0.7, eos_token_id=None):
        """Autoregressively extend ``input_ids`` and return the full sequences.

        Args:
            input_ids: integer token ids of shape (batch, seq) used as prompt.
            max_new_tokens: upper bound on the number of tokens appended.
            temperature: softmax temperature for sampling (clamped to at
                least 0.01); ``0``/``None`` selects greedy argmax decoding.
            eos_token_id: optional end-of-sequence id. A row is finished
                once it has generated this id; finished rows are padded with
                ``eos_token_id`` and generation stops when every row is
                finished.

        Returns:
            Tensor of shape (batch, seq + n_generated) holding the prompt
            followed by the generated tokens.

        Note:
            The model is switched to eval mode and left in that mode.
        """
        self.eval()
        max_len = self.pos_emb.num_embeddings
        # Per-row flag: has this sequence already produced EOS?
        finished = torch.zeros(
            (input_ids.shape[0], 1), dtype=torch.bool, device=input_ids.device
        )
        for _ in range(int(max_new_tokens)):
            # Only the most recent ``max_len`` tokens fit the position table;
            # older context is dropped so long generations keep working.
            context = input_ids[:, -max_len:]
            logits = self.forward(context)[:, -1, :]         # (B, V)
            if temperature and temperature > 0.0:
                probs = F.softmax(logits / max(0.01, temperature), dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
            if eos_token_id is not None:
                # Rows that already ended keep emitting EOS instead of
                # whatever the model would produce after its own EOS.
                next_id = next_id.masked_fill(finished, eos_token_id)
                finished = finished | (next_id == eos_token_id)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            if eos_token_id is not None and finished.all():
                break
        return input_ids