import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection."""

    def __init__(self, d_model, nhead):
        super().__init__()
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.nhead = nhead
        self.d_model = d_model

    def forward(self, x):
        B, T, C = x.size()
        # Project once, then split into query, key, and value.
        qkv = self.qkv_proj(x)
        q, k, v = qkv.chunk(3, dim=-1)
        # Reshape to (B, nhead, T, head_dim).
        q = q.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        k = k.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        v = v.view(B, T, self.nhead, C // self.nhead).transpose(1, 2)
        # Scaled dot-product attention (no causal mask is applied here).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (C // self.nhead) ** 0.5
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back into (B, T, C) and apply the output projection.
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, dim_feedforward):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(),  # ✅ Important: was present in the training model
            nn.Linear(dim_feedforward, d_model)
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm, then attention/FFN, then residual add."""

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.attn = SelfAttention(d_model, nhead)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dim_feedforward)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
class EvoDecoder(nn.Module):
def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=512):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(512, d_model)
self.blocks = nn.Sequential(*[
TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
])
self.ln_f = nn.LayerNorm(d_model)
self.fc_out = nn.Linear(d_model, vocab_size)
def forward(self, x):
B, T = x.size()
tok = self.token_emb(x)
pos = self.pos_emb(torch.arange(T, device=x.device).unsqueeze(0).expand(B, T))
x = tok + pos
x = self.blocks(x)
x = self.ln_f(x)
return self.fc_out(x)
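

# --- Minimal usage sketch (not part of the original file): a smoke test assuming a
# hypothetical vocabulary of 1000 tokens and a random batch of token ids. The shapes
# follow directly from EvoDecoder's definition above.
if __name__ == "__main__":
    model = EvoDecoder(vocab_size=1000)
    tokens = torch.randint(0, 1000, (2, 16))  # (batch=2, sequence_length=16)
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])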