a-ragab-h-m committed on
Commit
7494603
·
verified ·
1 Parent(s): 9ebaa0c

Update nets/model.py

Browse files
Files changed (1) hide show
  1. nets/model.py +7 -10
nets/model.py CHANGED
@@ -4,6 +4,7 @@ from nets.decoder import Decoder
4
  from nets.projections import Projections
5
  from nets.encoder import Encoder
6
 
 
7
  class Model(nn.Module):
8
  def __init__(self, input_size, embedding_size,
9
  num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
@@ -11,34 +12,30 @@ class Model(nn.Module):
11
 
12
  self.embedding_size = embedding_size
13
 
14
- # ---------------------- Encoder ------------------------
15
- # Takes input features like: [lng, lat, start_time, end_time, depot_flag]
16
  self.encoder = Encoder(
17
  n_heads=num_heads,
18
  embed_dim=embedding_size,
19
  n_layers=num_layers,
20
  feed_forward_hidden=ff_hidden,
21
- node_dim=input_size # ← هذا هو عدد الخصائص للعقدة (5 غالبًا)
22
  )
23
 
24
- # ---------------------- Decoder ------------------------
25
- # سابقًا كانت ثابتة 4*embedding + 1, هنا نخليها أكثر مرونة
26
- decoder_input_dim = 3 * embedding_size + 2 # [encoder_output, mask, fleet_info] تقريبًا
27
-
28
  self.decoder = Decoder(
29
  input_size=decoder_input_dim,
30
  embedding_size=embedding_size,
31
  num_heads=num_heads
32
  )
33
 
34
- # ---------------------- Attention Projections ------------------------
35
  self.projections = Projections(
36
  n_heads=num_heads,
37
  embed_dim=embedding_size
38
  )
39
 
40
- # ---------------------- Fleet Attention ------------------------
41
- # هذه تستخدم بيانات الأسطول مثل [start_time, location_embedding]
42
  self.fleet_attention = Encoder(
43
  n_heads=num_heads,
44
  embed_dim=embedding_size,
 
4
  from nets.projections import Projections
5
  from nets.encoder import Encoder
6
 
7
+
8
  class Model(nn.Module):
9
  def __init__(self, input_size, embedding_size,
10
  num_heads=8, num_layers=4, ff_hidden=250, *args, **kwargs):
 
12
 
13
  self.embedding_size = embedding_size
14
 
15
+ # ----------- Encoder -----------
 
16
  self.encoder = Encoder(
17
  n_heads=num_heads,
18
  embed_dim=embedding_size,
19
  n_layers=num_layers,
20
  feed_forward_hidden=ff_hidden,
21
+ node_dim=input_size
22
  )
23
 
24
+ # ----------- Decoder -----------
25
+ decoder_input_dim = 3 * embedding_size + 2 # example: [encoder output] + [mask info] + [fleet info]
 
 
26
  self.decoder = Decoder(
27
  input_size=decoder_input_dim,
28
  embedding_size=embedding_size,
29
  num_heads=num_heads
30
  )
31
 
32
+ # ----------- Attention Projections -----------
33
  self.projections = Projections(
34
  n_heads=num_heads,
35
  embed_dim=embedding_size
36
  )
37
 
38
+ # ----------- Fleet Attention Encoder (Optional) -----------
 
39
  self.fleet_attention = Encoder(
40
  n_heads=num_heads,
41
  embed_dim=embedding_size,