# EvoConvo / evo_decoder.py
# Last commit by HemanM: "Rename evo_model.py to evo_decoder.py" (749bc03, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single ``qkv_proj`` Linear produces queries, keys and values in one
    matmul. They are split into ``nhead`` heads of width ``d_model // nhead``,
    attended independently, merged back, and passed through ``out_proj``.

    Args:
        d_model: Model width. Must be a positive multiple of ``nhead``.
        nhead: Number of attention heads.
        causal: If True, position ``i`` may only attend to positions ``<= i``.
            Defaults to False, which reproduces this module's original
            unmasked (bidirectional) behaviour. The flag adds no parameters,
            so state dicts are identical either way.

    Raises:
        ValueError: If ``nhead`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, nhead, causal=False):
        super().__init__()
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f"nhead must be a positive integer dividing d_model; "
                f"got d_model={d_model}, nhead={nhead}"
            )
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.nhead = nhead
        self.d_model = d_model
        self.causal = causal

    def forward(self, x):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.size()
        head_dim = C // self.nhead
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, nhead, T, head_dim) so heads attend independently.
        q = q.view(B, T, self.nhead, head_dim).transpose(1, 2)
        k = k.view(B, T, self.nhead, head_dim).transpose(1, 2)
        v = v.view(B, T, self.nhead, head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim ** 0.5
        # getattr keeps whole-module pickles saved before `causal` existed loadable.
        if getattr(self, "causal", False):
            # Strictly-upper triangle = future positions. The diagonal stays
            # unmasked, so every row keeps a finite score and softmax never
            # produces NaN.
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
            scores = scores.masked_fill(future, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back: (B, nhead, T, head_dim) -> (B, T, C).
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        dim_feedforward: Hidden width of the expansion layer.
        dropout: Drop probability applied to the hidden activations. Defaults
            to 0.5, which is ``nn.Dropout``'s own default and therefore what
            this module used before the argument existed.
    """

    def __init__(self, d_model, dim_feedforward, dropout=0.5):
        super().__init__()
        # The Dropout must stay at index 2 of this Sequential (it was present
        # in the training model). nn.Sequential names its children by
        # position, so removing it would rename the second Linear's parameters
        # from ``net.3.*`` to ``net.2.*`` and break loading state dicts saved
        # with this layout.
        self.net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x`` (width ``d_model``)."""
        return self.net(x)
class TransformerBlock(nn.Module):
    """A single pre-LayerNorm transformer layer.

    Each sub-layer reads a LayerNorm-ed view of the running activations, and
    its output is added back onto the un-normalised residual stream:

        h = x + attn(ln1(x))
        y = h + ffn(ln2(h))
    """

    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        # Registration order (attn, ln1, ffn, ln2) determines the order of
        # parameters() and state_dict() entries, so it is kept fixed.
        self.attn = SelfAttention(d_model, nhead)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, dim_feedforward)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run attention then feed-forward, each as a pre-norm residual step."""
        stream = x
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.ffn)):
            stream = stream + sublayer(norm(stream))
        return stream
class EvoDecoder(nn.Module):
    """Transformer stack that maps token ids to per-position vocabulary logits.

    Pipeline: token embedding + learned absolute position embedding ->
    ``num_layers`` ``TransformerBlock``s -> final LayerNorm -> linear
    projection to ``vocab_size`` logits.

    NOTE(review): nothing in this class requests causal masking
    (``self.blocks(h)`` receives only the activations). Unless the blocks'
    attention masks internally, every position can attend to later positions.
    Confirm this matches how the checkpoint was trained; an autoregressive
    decoder normally needs a causal mask.

    Args:
        vocab_size: Number of token ids. Sets the rows of the token embedding
            and the width of the output logits.
        d_model: Embedding / hidden width.
        nhead: Attention heads per block.
        num_layers: Number of stacked ``TransformerBlock``s.
        dim_feedforward: Hidden width of each block's feed-forward layer.
        max_seq_len: Number of rows in the learned position table, i.e. the
            longest sequence ``forward`` accepts. Defaults to 512, the
            previously hard-coded value, so existing checkpoints keep the same
            ``pos_emb.weight`` shape.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dim_feedforward=512,
                 max_seq_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.Sequential(*[
            TransformerBlock(d_model, nhead, dim_feedforward) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.max_seq_len = max_seq_len

    def forward(self, x):
        """Return logits of shape ``(B, T, vocab_size)`` for token ids ``x`` of shape ``(B, T)``.

        Raises:
            IndexError: If ``T`` exceeds the size of the position table.
        """
        _, seq_len = x.size()
        # Read the limit from the table itself. This stays correct for modules
        # pickled before ``max_seq_len`` was stored as an attribute.
        limit = self.pos_emb.num_embeddings
        if seq_len > limit:
            raise IndexError(
                f"sequence length {seq_len} exceeds the position table size "
                f"{limit}; truncate the input or build the model with a "
                f"larger max_seq_len"
            )
        positions = torch.arange(seq_len, device=x.device)
        # (T, d_model) position rows broadcast across the batch dimension.
        h = self.token_emb(x) + self.pos_emb(positions)
        h = self.blocks(h)
        return self.fc_out(self.ln_f(h))