# EvoConvo / evo_model.py
# Provenance: Hugging Face repo (uploader HemanM, commit e7984f7 "Update evo_model.py",
# 1.06 kB). The original page chrome was not valid Python and is preserved here as comments.
import torch
import torch.nn as nn
class EvoDecoderModel(nn.Module):
    """Transformer-decoder language model with learned positional embeddings.

    Token embeddings plus a learned positional table are fed through a stack of
    ``nn.TransformerDecoderLayer``s and projected back to vocabulary logits.
    When no encoder ``memory`` is supplied, an all-zeros memory of the same
    shape as the target embeddings is used, so the model behaves as a
    decoder-only LM whose cross-attention contributes nothing informative.

    Args:
        vocab_size: size of the token vocabulary (embedding rows / logit columns).
        d_model: embedding / hidden dimension.
        nhead: number of attention heads per layer.
        num_layers: number of stacked decoder layers.
        dim_feedforward: inner dimension of each layer's feed-forward block.
        dropout: dropout probability inside the decoder layers.
        max_len: maximum supported sequence length (size of the positional
            table). Default 512, matching the previous hard-coded value.
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3,
                 dim_feedforward=1024, dropout=0.1, max_len=512):
        super().__init__()
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings, zero-initialized (as in the original);
        # they are trained jointly with the rest of the model.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory=None):
        """Compute vocabulary logits for a batch of token sequences.

        Args:
            tgt: LongTensor of token ids, shape ``(batch, seq_len)``.
            memory: optional encoder output of shape ``(batch, mem_len, d_model)``;
                when ``None``, a zeros tensor shaped like the target embeddings
                is used as dummy memory (original behavior).

        Returns:
            FloatTensor of logits, shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len`` (previously this
                failed with an opaque broadcasting error).
        """
        seq_len = tgt.size(1)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        embedded = self.embedding(tgt) + self.pos_embedding[:, :seq_len, :]
        # If no memory is provided, use dummy memory filled with zeros.
        if memory is None:
            memory = torch.zeros_like(embedded)
        # BUG FIX: the original passed no tgt_mask, so self-attention could see
        # future tokens — the "decoder" was not autoregressive. Apply the
        # standard causal (upper-triangular -inf) mask.
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=tgt.device),
            diagonal=1,
        )
        # nn.TransformerDecoder expects (seq, batch, d_model) without
        # batch_first, hence the transposes in and out.
        output = self.transformer_decoder(
            embedded.transpose(0, 1),
            memory.transpose(0, 1),
            tgt_mask=causal_mask,
        )
        return self.output_layer(output.transpose(0, 1))