Update nets/decoder.py
Browse files- nets/decoder.py +1 -1
nets/decoder.py
CHANGED
@@ -82,7 +82,7 @@ class Decoder(nn.Module):
|
|
82 |
def __init__(self, num_heads, embedding_size, decoder_input_size, softmax_output=False, C=10):
|
83 |
super().__init__()
|
84 |
self.embedding_size = embedding_size
|
85 |
-
self.initial_embedding = nn.Linear(decoder_input_size
|
86 |
self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
|
87 |
self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
|
88 |
|
|
|
82 |
def __init__(self, num_heads, embedding_size, decoder_input_size, softmax_output=False, C=10):
|
83 |
super().__init__()
|
84 |
self.embedding_size = embedding_size
|
85 |
+
self.initial_embedding = nn.Linear(decoder_input_size - 1, embedding_size)
|
86 |
self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
|
87 |
self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
|
88 |
|