File size: 688 Bytes
7e4e9e0
ccff75d
 
defaa9b
1070f67
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch.nn as nn

class EvoDecoderModel(nn.Module):
    """Transformer decoder that maps target token ids to per-position vocabulary logits.

    Pipeline: ``nn.Embedding`` -> stack of ``nn.TransformerDecoderLayer``
    (self-attention over the target plus cross-attention to ``memory``) ->
    ``nn.Linear`` projection back to ``vocab_size`` logits.

    The decoder layers are built with PyTorch's default ``batch_first=False``,
    so tensors are sequence-first: ``tgt`` is ``(T, N)`` token ids and
    ``memory`` is ``(S, N, d_model)``.

    NOTE(review): no positional encoding is added to the token embeddings, so
    target order is only visible to the model through ``tgt_mask`` (e.g. a
    causal mask). Adding one would change the parameter set / ``state_dict``,
    so it is intentionally not done here — confirm this is desired.

    Args:
        vocab_size: Number of token ids; size of the embedding table and of
            the output logit dimension.
        d_model: Embedding / hidden width of the decoder.
        nhead: Number of attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside each decoder layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Template layer: nn.TransformerDecoder deep-copies it num_layers times,
        # so the stacked layers do not share weights.
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode ``tgt`` against ``memory`` and return vocabulary logits.

        Args:
            tgt: Target token ids, ``(T, N)`` in the sequence-first layout.
            memory: Encoder output / context, ``(S, N, d_model)``.
            tgt_mask: Optional ``(T, T)`` attention mask for target
                self-attention. Pass a causal (upper-triangular) mask for
                autoregressive training; without it every position can attend
                to future tokens.
            memory_mask: Optional ``(T, S)`` mask for cross-attention.
            tgt_key_padding_mask: Optional ``(N, T)`` mask marking padded
                target positions.
            memory_key_padding_mask: Optional ``(N, S)`` mask marking padded
                memory positions.

        Returns:
            Logits of shape ``(T, N, vocab_size)``.

        With all masks left as ``None`` this is identical to the unmasked call.
        """
        embedded = self.embedding(tgt)
        output = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(output)