a-ragab-h-m commited on
Commit
36a9522
·
verified ·
1 Parent(s): f8ecbbc

Update google_solver/scratch.py

Browse files
Files changed (1) hide show
  1. google_solver/scratch.py +20 -14
google_solver/scratch.py CHANGED
@@ -1,27 +1,33 @@
1
-
2
-
3
  from just_time_windows.google_solver.google_model import evaluate_google_model
4
  from just_time_windows.Actor.actor import Actor as NN_Actor
5
  from just_time_windows.build_data import Raw_VRP_Data
6
  from just_time_windows.dataloader import VRP_Dataset
7
 
 
 
 
8
 
9
- dataset = VRP_Dataset(dataset_size=10, num_depots=1, num_nodes=12)
10
-
11
- batch = dataset.get_batch(0, 10)
12
-
13
- nn_actor = NN_Actor(model=None, num_movers=10, num_neighbors_action=1)
14
 
 
 
15
 
16
- nn_output = nn_actor(batch)
17
- time = nn_output['total_time']
18
- arrival_times = nn_output['arrival_times']
 
19
 
 
 
20
 
21
- output = evaluate_google_model(dataset)
 
 
22
 
 
 
23
 
24
 
25
- print(arrival_times)
26
- print(time.mean().item())
27
- print(output.mean().item())
 
 
 
1
  from just_time_windows.google_solver.google_model import evaluate_google_model
2
  from just_time_windows.Actor.actor import Actor as NN_Actor
3
  from just_time_windows.build_data import Raw_VRP_Data
4
  from just_time_windows.dataloader import VRP_Dataset
5
 
6
+ def main():
7
+ # إعداد مجموعة بيانات صغيرة للاختبار
8
+ dataset = VRP_Dataset(dataset_size=10, num_depots=1, num_nodes=12)
9
 
10
+ # استخراج دفعة واحدة للاختبار
11
+ batch = dataset.get_batch(start_index=0, batch_size=10)
 
 
 
12
 
13
+ # تهيئة نموذج الشبكة العصبية
14
+ nn_actor = NN_Actor(model=None, num_movers=10, num_neighbors_action=1)
15
 
16
+ # حساب مخرجات NN
17
+ nn_output = nn_actor(batch)
18
+ total_time_nn = nn_output['total_time']
19
+ arrival_times_nn = nn_output['arrival_times']
20
 
21
+ # استخدام Google OR-Tools لتقييم نفس البيانات
22
+ google_output = evaluate_google_model(dataset)
23
 
24
+ # طباعة النتائج للمقارنة
25
+ print("Arrival times (NN):")
26
+ print(arrival_times_nn)
27
 
28
+ print("\nAverage Total Time (NN Actor):", total_time_nn.mean().item())
29
+ print("Average Total Time (Google OR-Tools):", google_output.mean().item())
30
 
31
 
32
+ if __name__ == '__main__':
33
+ main()