Update nets/model.py
Browse files- nets/model.py +7 -10
nets/model.py
CHANGED
@@ -4,6 +4,7 @@ from nets.decoder import Decoder
|
|
4 |
from nets.projections import Projections
|
5 |
from nets.encoder import Encoder
|
6 |
|
|
|
7 |
class Model(nn.Module):
|
8 |
def __init__(self, input_size, embedding_size,
|
9 |
num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
|
@@ -11,34 +12,30 @@ class Model(nn.Module):
|
|
11 |
|
12 |
self.embedding_size = embedding_size
|
13 |
|
14 |
-
#
|
15 |
-
# Takes input features like: [lng, lat, start_time, end_time, depot_flag]
|
16 |
self.encoder = Encoder(
|
17 |
n_heads=num_heads,
|
18 |
embed_dim=embedding_size,
|
19 |
n_layers=num_layers,
|
20 |
feed_forward_hidden=ff_hidden,
|
21 |
-
node_dim=input_size
|
22 |
)
|
23 |
|
24 |
-
#
|
25 |
-
|
26 |
-
decoder_input_dim = 3 * embedding_size + 2 # [encoder_output, mask, fleet_info] تقريبًا
|
27 |
-
|
28 |
self.decoder = Decoder(
|
29 |
input_size=decoder_input_dim,
|
30 |
embedding_size=embedding_size,
|
31 |
num_heads=num_heads
|
32 |
)
|
33 |
|
34 |
-
#
|
35 |
self.projections = Projections(
|
36 |
n_heads=num_heads,
|
37 |
embed_dim=embedding_size
|
38 |
)
|
39 |
|
40 |
-
#
|
41 |
-
# هذه تستخدم بيانات الأسطول مثل [start_time, location_embedding]
|
42 |
self.fleet_attention = Encoder(
|
43 |
n_heads=num_heads,
|
44 |
embed_dim=embedding_size,
|
|
|
4 |
from nets.projections import Projections
|
5 |
from nets.encoder import Encoder
|
6 |
|
7 |
+
|
8 |
class Model(nn.Module):
|
9 |
def __init__(self, input_size, embedding_size,
|
10 |
num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
|
|
|
12 |
|
13 |
self.embedding_size = embedding_size
|
14 |
|
15 |
+
# ----------- Encoder -----------
|
|
|
16 |
self.encoder = Encoder(
|
17 |
n_heads=num_heads,
|
18 |
embed_dim=embedding_size,
|
19 |
n_layers=num_layers,
|
20 |
feed_forward_hidden=ff_hidden,
|
21 |
+
node_dim=input_size
|
22 |
)
|
23 |
|
24 |
+
# ----------- Decoder -----------
|
25 |
+
decoder_input_dim = 3 * embedding_size + 2 # example: [encoder output] + [mask info] + [fleet info]
|
|
|
|
|
26 |
self.decoder = Decoder(
|
27 |
input_size=decoder_input_dim,
|
28 |
embedding_size=embedding_size,
|
29 |
num_heads=num_heads
|
30 |
)
|
31 |
|
32 |
+
# ----------- Attention Projections -----------
|
33 |
self.projections = Projections(
|
34 |
n_heads=num_heads,
|
35 |
embed_dim=embedding_size
|
36 |
)
|
37 |
|
38 |
+
# ----------- Fleet Attention Encoder (Optional) -----------
|
|
|
39 |
self.fleet_attention = Encoder(
|
40 |
n_heads=num_heads,
|
41 |
embed_dim=embedding_size,
|