File size: 1,055 Bytes
0bd71c9 7e4e9e0 ccff75d e7984f7 1070f67 e7984f7 0bd71c9 e7984f7 0bd71c9 e7984f7 0bd71c9 e7984f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
class EvoDecoderModel(nn.Module):
    """Autoregressive Transformer-decoder language model.

    Token embeddings plus learned (zero-initialized) positional embeddings are
    fed through a stack of ``nn.TransformerDecoder`` layers and projected back
    to vocabulary logits. When no encoder memory is supplied, an all-zeros
    memory of the same shape is used, so cross-attention contributes a
    constant and the model behaves as a decoder-only LM.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Embedding / model dimension.
        nhead: Number of attention heads.
        num_layers: Number of decoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward block.
        dropout: Dropout probability used throughout the decoder.
        max_len: Maximum supported sequence length (length of the learned
            positional-embedding table). Defaults to 512, matching the
            previous hard-coded value.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3,
                 dim_feedforward=1024, dropout=0.1, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings; starting at zero lets training shape
        # them from scratch. Shape (1, max_len, d_model) broadcasts over batch.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory=None, tgt_mask=None):
        """Compute vocabulary logits for a batch of token sequences.

        Args:
            tgt: Long tensor of token ids, shape (batch, seq_len).
            memory: Optional encoder output, shape (batch, mem_len, d_model).
                If ``None``, a zeros tensor shaped like the embedded target is
                used as dummy memory.
            tgt_mask: Optional attention mask for the target sequence. If
                ``None``, a causal mask is generated so position ``i`` cannot
                attend to positions ``> i`` — required for autoregressive use.

        Returns:
            Float tensor of logits, shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = tgt.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]
        # If no memory is provided, use dummy memory filled with zeros.
        if memory is None:
            memory = torch.zeros_like(embedded)
        if tgt_mask is None:
            # Causal mask (-inf above the diagonal): blocks attention to
            # future positions, which the original code incorrectly allowed.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)
        # TransformerDecoder expects (seq, batch, d_model) by default.
        output = self.transformer_decoder(
            embedded.transpose(0, 1), memory.transpose(0, 1), tgt_mask=tgt_mask
        )
        return self.output_layer(output.transpose(0, 1))
|