a-ragab-h-m commited on
Commit
c646511
·
verified ·
1 Parent(s): 615dabe

Update nets/model.py

Browse files
Files changed (1) hide show
  1. nets/model.py +2 -2
nets/model.py CHANGED
@@ -7,6 +7,7 @@ from nets.encoder import Encoder
7
 
8
  class Model(nn.Module):
9
  def __init__(self, input_size, embedding_size,
 
10
  num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
11
  super().__init__()
12
 
@@ -22,9 +23,8 @@ class Model(nn.Module):
22
  )
23
 
24
  # ----------- Decoder -----------
25
- decoder_input_dim = 3 * embedding_size + 2 # example: [encoder output] + [mask info] + [fleet info]
26
  self.decoder = Decoder(
27
- input_size=decoder_input_dim,
28
  embedding_size=embedding_size,
29
  num_heads=num_heads
30
  )
 
7
 
8
  class Model(nn.Module):
9
  def __init__(self, input_size, embedding_size,
10
+ decoder_input_size,
11
  num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
12
  super().__init__()
13
 
 
23
  )
24
 
25
  # ----------- Decoder -----------
 
26
  self.decoder = Decoder(
27
+ decoder_input_size=decoder_input_size,
28
  embedding_size=embedding_size,
29
  num_heads=num_heads
30
  )