Update nets/decoder.py
Browse files- nets/decoder.py +2 -2
nets/decoder.py
CHANGED
@@ -79,10 +79,10 @@ class Attention(nn.Module):
|
|
79 |
|
80 |
|
81 |
class Decoder(nn.Module):
|
82 |
-
def __init__(self, num_heads,
|
83 |
super().__init__()
|
84 |
self.embedding_size = embedding_size
|
85 |
-
self.initial_embedding = nn.Linear(
|
86 |
self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
|
87 |
self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
|
88 |
|
|
|
79 |
|
80 |
|
81 |
class Decoder(nn.Module):
|
82 |
+
def __init__(self, num_heads, embedding_size, decoder_input_size, softmax_output=False, C=10):
|
83 |
super().__init__()
|
84 |
self.embedding_size = embedding_size
|
85 |
+
self.initial_embedding = nn.Linear(decoder_input_size, embedding_size)
|
86 |
self.attention = Attention(n_heads=num_heads, input_dim=embedding_size, embed_dim=embedding_size)
|
87 |
self.classifier_output = ClassifierOutput(embedding_size=embedding_size, C=C, softmax_output=softmax_output)
|
88 |
|