File size: 688 Bytes
7e4e9e0 ccff75d defaa9b 1070f67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch.nn as nn
class EvoDecoderModel(nn.Module):
    """Transformer decoder that maps target token ids to vocabulary logits.

    Target tokens are embedded, passed through a stack of
    ``nn.TransformerDecoderLayer`` blocks that cross-attend to ``memory``,
    and projected back to ``vocab_size`` logits.

    NOTE(review): no positional encoding is added to the embeddings, so apart
    from whatever mask is supplied the self-attention is insensitive to token
    order. Confirm this is intended.

    Args:
        vocab_size: Size of the embedding table and width of the output logits.
        d_model: Embedding / hidden size shared by all layers.
        nhead: Attention heads per layer (PyTorch requires it to divide
            ``d_model``).
        num_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sublayer.
        dropout: Dropout probability inside each decoder layer.
        batch_first: ``False`` (default, the original behaviour) means inputs
            are sequence-first: ``tgt`` is ``(tgt_len, batch)`` and ``memory``
            is ``(src_len, batch, d_model)``. ``True`` puts batch first.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6,
                 dim_feedforward=2048, dropout=0.1, batch_first=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=batch_first
        )
        # nn.TransformerDecoder deep-copies the prototype layer num_layers times.
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            tgt: Integer token ids, embedded to ``tgt.shape + (d_model,)``.
            memory: Encoder output to cross-attend to (last dim ``d_model``).
            tgt_mask: Optional ``(tgt_len, tgt_len)`` self-attention mask.
                Pass a causal (upper-triangular ``-inf``) mask for
                autoregressive training. Without it every position can see
                future tokens.
            memory_mask: Optional ``(tgt_len, src_len)`` cross-attention mask.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` padding mask
                for ``tgt``.
            memory_key_padding_mask: Optional ``(batch, src_len)`` padding
                mask for ``memory``.

        Returns:
            Logits of shape ``tgt.shape + (vocab_size,)``.
        """
        embedded = self.embedding(tgt)
        output = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(output)
|