import torch
import torch.nn as nn


class EvoDecoderModel(nn.Module):
    """Decoder-only Transformer: maps token ids to next-token vocabulary logits."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3,
                 dim_feedforward=1024, dropout=0.1, max_len=512):
        super().__init__()
        self.max_len = max_len
        # Token embedding plus a learnable positional embedding
        # (zero-initialized, trained with the rest of the model).
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        # Project decoder states back to vocabulary logits.
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory=None):
        # tgt: (batch, seq_len) token ids.
        seq_len = tgt.size(1)
        assert seq_len <= self.max_len, "sequence length exceeds positional embedding table"
        embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]
        # Causal mask: -inf above the diagonal so each position attends
        # only to itself and earlier positions.
        tgt_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=tgt.device),
            diagonal=1,
        )
        # Decoder-only operation: with no encoder output, fall back to an all-zero memory.
        if memory is None:
            memory = torch.zeros_like(embedded)
        # nn.TransformerDecoder defaults to (seq, batch, d_model) layout,
        # hence the transposes around the call.
        output = self.transformer_decoder(
            embedded.transpose(0, 1), memory.transpose(0, 1), tgt_mask=tgt_mask
        )
        return self.output_layer(output.transpose(0, 1))
|
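# A minimal usage sketch (not part of the original module): the batch size,
# sequence length, and vocab_size below are illustrative assumptions, used
# only to sanity-check tensor shapes through a single forward pass.
if __name__ == "__main__":
    vocab_size = 1000
    model = EvoDecoderModel(vocab_size)
    tokens = torch.randint(0, vocab_size, (2, 16))  # (batch, seq_len) token ids
    logits = model(tokens)  # memory=None, so the decoder runs standalone
    print(logits.shape)     # expected: torch.Size([2, 16, 1000])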