EvoConvo / evo_model.py
HemanM's picture
Update evo_model.py
0bd71c9 verified
raw
history blame
1.94 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class EvoDecoderBlock(nn.Module):
    """Post-norm transformer block: self-attention, then a feed-forward net.

    Each sub-layer is wrapped in a residual connection followed by
    ``LayerNorm`` (``x = LN(x + sublayer(x))``).

    Args:
        d_model: Width of the input/output features.
        nhead: Number of attention heads handed to ``nn.MultiheadAttention``.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the attention module and
            between the two feed-forward linear layers.
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        # NOTE: nn.MultiheadAttention already owns input and output
        # projections, so these two layers stack an extra linear map on each
        # side of it. They are kept (with their names) because they are part
        # of the state_dict and removing them would break saved checkpoints.
        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        """Run the block on ``x``.

        Args:
            x: Tensor whose last dimension is ``d_model``; the attention
                module is built with ``batch_first=True``, so batched input
                is laid out as ``(batch, seq, d_model)``.
            attn_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``. With the default ``None`` every
                position attends to every other position (no causal
                masking); pass an upper-triangular mask for autoregressive
                use.
            key_padding_mask: Optional padding mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention with skip connection
        q, k, v = torch.chunk(self.qkv_proj(x), 3, dim=-1)
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.attn(
            q,
            k,
            v,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        x = self.ln1(x + self.out_proj(attn_output))
        # Feedforward with skip connection
        x = self.ln2(x + self.ffn(x))
        return x
class EvoDecoderModel(nn.Module):
    """Token-level transformer built from a stack of ``EvoDecoderBlock`` layers.

    Token ids are embedded, summed with a learned positional embedding,
    passed through ``num_layers`` blocks and a final ``LayerNorm``, then
    projected to ``vocab_size`` scores per position.

    The blocks are called without an attention mask, so each position
    attends to the whole sequence.

    Args:
        vocab_size: Number of rows in the token embedding and width of the
            output projection.
        d_model: Embedding / hidden width.
        nhead: Attention heads per block.
        num_layers: Number of stacked blocks.
        dim_feedforward: Hidden width of each block's feed-forward network.
        dropout: Dropout probability passed to every block.
        max_len: Number of learned positions; the longest sequence that
            ``forward`` accepts.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.blocks = nn.ModuleList([
            EvoDecoderBlock(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids to per-position vocabulary scores.

        Args:
            x: Integer tensor of token ids, indexed as ``(batch, seq)``
                (dimension 1 is used as the sequence length).

        Returns:
            Tensor of shape ``(batch, seq, vocab_size)``.

        Raises:
            IndexError: If the sequence is longer than the positional
                embedding table (``max_len``).
        """
        seq_len = x.size(1)
        max_len = self.pos_emb.num_embeddings
        # Fail early with a readable message. Without this check the
        # embedding lookup raises an opaque "index out of range" on CPU and
        # a device-side assert on CUDA.
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "supported by the positional embedding"
            )
        # Shape (1, seq) so the positions broadcast over the batch dimension.
        pos = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.fc_out(x)