import torch
import torch.nn as nn

from nets.decoder import Decoder
from nets.projections import Projections
from nets.encoder import Encoder


class Model(nn.Module):
    """Bundle an encoder, a decoder, attention projections and a fleet encoder.

    The constructor only instantiates and registers the four submodules; it
    performs no computation of its own.
    """

    def __init__(
        self,
        input_size,
        embedding_size,
        decoder_input_size,
        num_heads=8,
        num_layers=4,
        ff_hidden=250,
        *args,
        **kwargs
    ) -> None:
        """Create the submodules.

        Args:
            input_size: Forwarded to the main encoder as ``node_dim``.
            embedding_size: Shared embedding width. Stored on the instance
                and forwarded to the encoder, decoder, projections and fleet
                encoder.
            decoder_input_size: Forwarded unchanged to the decoder.
            num_heads: Head count handed to every submodule.
            num_layers: Layer count of the main encoder only.
            ff_hidden: ``feed_forward_hidden`` value used by both encoders.
            *args: Accepted for call-site compatibility; ignored.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        super().__init__()

        self.embedding_size = embedding_size

        # NOTE: submodules are created in a fixed order (encoder, decoder,
        # projections, fleet_attention); this is also the order in which
        # nn.Module registers them.

        # Main encoder over the raw inputs.
        self.encoder = Encoder(
            n_heads=num_heads,
            embed_dim=embedding_size,
            n_layers=num_layers,
            feed_forward_hidden=ff_hidden,
            node_dim=input_size,
        )

        # Decoder.
        self.decoder = Decoder(
            decoder_input_size=decoder_input_size,
            embedding_size=embedding_size,
            num_heads=num_heads,
        )

        # Attention projections.
        self.projections = Projections(
            n_heads=num_heads,
            embed_dim=embedding_size,
        )

        # Fleet attention encoder: a single-layer Encoder. It was labelled
        # "optional" originally, but it is always constructed here; whether it
        # is used is decided by callers.
        # NOTE(review): node_dim is embedding_size + 1, presumably an embedding
        # concatenated with one extra scalar feature -- confirm against caller.
        self.fleet_attention = Encoder(
            n_heads=num_heads,
            embed_dim=embedding_size,
            n_layers=1,
            feed_forward_hidden=ff_hidden,
            node_dim=embedding_size + 1,
        )