import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection (no causal mask)."""

    def __init__(self, d_model, nhead):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.nhead = nhead
        self.d_model = d_model

    def forward(self, x):
        B, T, C = x.size()
        # Single projection, then split into query, key, and value.
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape to (B, nhead, T, head_dim) so attention runs per head.
        q = q.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        k = k.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        v = v.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)

        # Scaled dot-product attention; note there is no causal mask, so every
        # position attends to every other position.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (C // self.nhead) ** 0.5
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads back to (B, T, d_model) and apply the output projection.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, dim_feedforward):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(),  # ✅ Important: was present in the training model (PyTorch default p=0.5)
            nn.Linear(dim_feedforward, d_model)
        )

    def forward(self, x):
        return self.net(x)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm before attention and before the FFN, each wrapped in a residual connection."""

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.attn = SelfAttention(d_model, nhead)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dim_feedforward)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class EvoDecoder(nn.Module):
    """Token + learned positional embeddings, a stack of transformer blocks, a final LayerNorm, and a projection to vocabulary logits."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(512, d_model)  # learned positions; supports sequences of up to 512 tokens
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        B, T = x.size()
        tok = self.token_emb(x)  # (B, T, d_model)
        # Positional embeddings for positions 0..T-1, broadcast across the batch.
        pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0).expand(B, T))
        x = tok + pos
        x = self.blocks(x)
        x = self.ln_f(x)
        return self.fc_out(x)  # (B, T, vocab_size) logits
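
# Minimal usage sketch: the vocabulary size, batch size, and sequence length
# below are illustrative assumptions, not values taken from any training setup.
if __name__ == "__main__":
    vocab_size = 1000
    model = EvoDecoder(vocab_size)
    model.eval()  # disable dropout for a deterministic forward pass
    tokens = torch.randint(0, vocab_size, (2, 16))  # (batch=2, seq_len=16); seq_len must be <= 512
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)  # torch.Size([2, 16, 1000])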