import torch
import torch.nn as nn
import torch.nn.functional as F


class EvoDecoderBlock(nn.Module):
    """A single decoder block: causal self-attention followed by a position-wise feed-forward network."""

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Multi-head self-attention; batch_first=True means inputs are [batch, seq, d_model].
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Fused projection producing query, key, and value, plus an output projection
        # applied to the attention result before the residual connection.
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        # Post-norm: layer norms applied after each residual connection.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Causal mask: True marks positions a query may NOT attend to, so each token
        # only sees itself and earlier tokens (required for autoregressive decoding).
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )

        # Self-attention sub-layer with residual connection and post-norm.
        qkv = self.qkv_proj(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        attn_output, _ = self.attn(q, k, v, attn_mask=causal_mask)
        x = self.ln1(x + self.out_proj(attn_output))

        # Feed-forward sub-layer with residual connection and post-norm.
        x = self.ln2(x + self.ffn(x))
        return x


class EvoDecoderModel(nn.Module):
    """A decoder-only Transformer: token and positional embeddings, a stack of EvoDecoderBlocks, and a vocabulary projection."""

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=512):
        super().__init__()
        # Token embeddings plus learned absolute positional embeddings (up to max_len positions).
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        # Stack of decoder blocks.
        self.blocks = nn.ModuleList([
            EvoDecoderBlock(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        # Final layer norm and projection to vocabulary logits.
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: [batch, seq] of token ids; seq must not exceed max_len, the size of pos_emb.
        device = x.device
        seq_len = x.size(1)
        # Position indices 0..seq_len-1, with a leading batch dimension for broadcasting.
        pos = torch.arange(0, seq_len, device=device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(pos)

        # Apply the decoder blocks in sequence.
        for block in self.blocks:
            x = block(x)

        # Final norm, then project to per-token logits over the vocabulary.
        x = self.ln_f(x)
        return self.fc_out(x)
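

# A minimal usage sketch (an assumption, not part of the original module): the vocabulary
# size, model dimensions, and random token ids below are illustrative values chosen only
# to check that a forward pass returns logits of shape [batch, seq, vocab_size].
if __name__ == "__main__":
    vocab_size = 1000
    model = EvoDecoderModel(vocab_size, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, max_len=64)
    tokens = torch.randint(0, vocab_size, (2, 16))  # batch of 2 sequences, 16 tokens each
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])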