File size: 1,055 Bytes
0bd71c9
7e4e9e0
ccff75d
 
e7984f7
1070f67
e7984f7
 
 
 
 
0bd71c9
e7984f7
 
 
0bd71c9
e7984f7
 
 
0bd71c9
e7984f7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn

class EvoDecoderModel(nn.Module):
    """Transformer decoder mapping token ids to per-position vocabulary logits.

    Token embeddings are summed with a learned positional embedding, run
    through a stack of ``nn.TransformerDecoderLayer`` blocks, and projected
    back to ``vocab_size`` logits by a final linear layer.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3,
                 dim_feedforward=1024, dropout=0.1, max_len=512):
        """Build the embedding, decoder stack and output projection.

        Args:
            vocab_size: Number of token ids; sizes both the embedding table
                and the output projection.
            d_model: Width of the embeddings and decoder layers.
            nhead: Attention heads per layer (passed to
                ``nn.TransformerDecoderLayer``).
            num_layers: Number of stacked decoder layers.
            dim_feedforward: Hidden width of each layer's feed-forward block.
            dropout: Dropout probability used inside the decoder layers.
            max_len: Longest sequence covered by the learned positional
                embedding. Defaults to 512, the previously hard-coded value,
                so existing checkpoints keep the same parameter shapes.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned (trainable) positional table, zero-initialised.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory=None, tgt_mask=None, causal=False):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            tgt: Integer token ids indexed as ``(batch, seq_len)``.
            memory: Optional batch-first memory ``(batch, mem_len, d_model)``;
                it is transposed to sequence-first before the decoder call.
                When ``None``, a zero tensor shaped like the embedded ``tgt``
                is used, so cross-attention receives no tgt-dependent input.
            tgt_mask: Optional ``(seq_len, seq_len)`` self-attention mask
                forwarded unchanged to ``nn.TransformerDecoder``.
            causal: If True, build a causal mask so position ``i`` attends
                only to positions ``<= i``. Defaults to False, which keeps the
                original unmasked behaviour.
                NOTE(review): autoregressive training almost certainly wants
                ``causal=True``; confirm against the training loop.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length,
                or if both ``causal`` and ``tgt_mask`` are given.
        """
        seq_len = tgt.size(1)
        max_len = self.pos_embedding.size(1)
        if seq_len > max_len:
            # Fail clearly instead of a broadcast size-mismatch in the add below.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]

        if causal:
            if tgt_mask is not None:
                raise ValueError("pass either causal=True or tgt_mask, not both")
            # Additive float mask: -inf strictly above the diagonal blocks
            # attention to later positions; zeros elsewhere leave scores intact.
            tgt_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"),
                           device=embedded.device, dtype=embedded.dtype),
                diagonal=1,
            )

        # If no memory is provided, use dummy memory filled with zeros.
        if memory is None:
            memory = torch.zeros_like(embedded)

        # nn.TransformerDecoder is used with its default sequence-first layout,
        # hence the transposes in and out.
        output = self.transformer_decoder(
            embedded.transpose(0, 1),
            memory.transpose(0, 1),
            tgt_mask=tgt_mask,
        )
        return self.output_layer(output.transpose(0, 1))