a-ragab-h-m commited on
Commit
615dabe
·
verified ·
1 Parent(s): e691424

Update nets/decoder.py

Browse files
Files changed (1) hide show
  1. nets/decoder.py +2 -2
nets/decoder.py CHANGED
@@ -79,10 +79,10 @@ class Attention(nn.Module):
79
 
80
 
81
  class Decoder(nn.Module):
82
- def __init__(self, num_heads, input_size, embedding_size, softmax_output=False, C=10):
83
  super().__init__()
84
  self.embedding_size = embedding_size
85
- self.initial_embedding = nn.Linear(input_size, embedding_size)
86
  self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
87
  self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
88
 
 
79
 
80
 
81
  class Decoder(nn.Module):
82
+ def __init__(self, num_heads, embedding_size, decoder_input_size, softmax_output=False, C=10):
83
  super().__init__()
84
  self.embedding_size = embedding_size
85
+ self.initial_embedding = nn.Linear(decoder_input_size, embedding_size)
86
  self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
87
  self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
88