a-ragab-h-m commited on
Commit
7663567
·
verified ·
1 Parent(s): ab7a9fb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +46 -11
inference.py CHANGED
@@ -2,26 +2,36 @@ import torch
2
  from torch.utils.data import DataLoader
3
  import json
4
  import os
 
 
5
 
6
  from nets.model import Model
7
  from Actor.actor import Actor
8
  from dataloader import VRP_Dataset
9
 
10
  # --- تحميل الإعدادات ---
11
- with open('/data/params_saved.json', 'r') as f:
 
 
 
 
12
  params = json.load(f)
13
 
14
  # --- تعيين الجهاز ---
15
  device = params['device']
16
  dataset_path = params['dataset_path']
17
- input_size = None # سيتم تحديده بعد تحميل البيانات
18
 
19
- # --- تحميل نموذج مدرب ---
20
  model_path = "/data/model_state_dict.pt"
21
  if not os.path.exists(model_path):
22
  raise FileNotFoundError(f"Model not found at {model_path}")
23
 
24
- # --- إعداد بيانات عشوائية للاختبار ---
 
 
 
 
 
25
  inference_dataset = VRP_Dataset(
26
  size=1,
27
  num_nodes=params['num_nodes'],
@@ -29,7 +39,6 @@ inference_dataset = VRP_Dataset(
29
  path=dataset_path,
30
  device=device
31
  )
32
-
33
  input_size = inference_dataset.model_input_length()
34
 
35
  # --- تحميل النموذج ---
@@ -40,7 +49,7 @@ model = Model(
40
  )
41
  model.load_state_dict(torch.load(model_path, map_location=device))
42
 
43
- # --- تهيئة الممثل (Actor) والـ NN Actor ---
44
  actor = Actor(model=model,
45
  num_movers=params['num_movers'],
46
  num_neighbors_encoder=params['num_neighbors_encoder'],
@@ -52,8 +61,10 @@ actor.eval_mode()
52
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
53
  nn_actor.nearest_neighbors()
54
 
55
- # --- تنفيذ الاستدلال على دفعة واحدة ---
56
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
 
 
57
  for batch in dataloader:
58
  with torch.no_grad():
59
  actor.greedy_search()
@@ -63,7 +74,31 @@ for batch in dataloader:
63
  nn_output = nn_actor(batch)
64
  nn_time = nn_output['total_time'].item()
65
 
66
- print("\n===== INFERENCE RESULT =====")
67
- print(f"Actor Model Total Cost: {total_time:.4f}")
68
- print(f"Nearest Neighbor Cost : {nn_time:.4f}")
69
- print(f"Improvement over NN : {(nn_time - total_time) / nn_time * 100:.2f}%")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from torch.utils.data import DataLoader
3
  import json
4
  import os
5
+ import csv
6
+ from datetime import datetime
7
 
8
  from nets.model import Model
9
  from Actor.actor import Actor
10
  from dataloader import VRP_Dataset
11
 
12
  # --- تحميل الإعدادات ---
13
+ params_path = '/data/params_saved.json'
14
+ if not os.path.exists(params_path):
15
+ raise FileNotFoundError(f"Settings file not found at {params_path}")
16
+
17
+ with open(params_path, 'r') as f:
18
  params = json.load(f)
19
 
20
  # --- تعيين الجهاز ---
21
  device = params['device']
22
  dataset_path = params['dataset_path']
 
23
 
24
+ # --- مسار النموذج المحفوظ ---
25
  model_path = "/data/model_state_dict.pt"
26
  if not os.path.exists(model_path):
27
  raise FileNotFoundError(f"Model not found at {model_path}")
28
 
29
+ # --- تحضير مجلد النتائج ---
30
+ os.makedirs("/data", exist_ok=True)
31
+ txt_results_file = "/data/inference_results.txt"
32
+ csv_results_file = "/data/inference_results.csv"
33
+
34
+ # --- إعداد بيانات inference ---
35
  inference_dataset = VRP_Dataset(
36
  size=1,
37
  num_nodes=params['num_nodes'],
 
39
  path=dataset_path,
40
  device=device
41
  )
 
42
  input_size = inference_dataset.model_input_length()
43
 
44
  # --- تحميل النموذج ---
 
49
  )
50
  model.load_state_dict(torch.load(model_path, map_location=device))
51
 
52
+ # --- تهيئة الممثل والـ NN Actor ---
53
  actor = Actor(model=model,
54
  num_movers=params['num_movers'],
55
  num_neighbors_encoder=params['num_neighbors_encoder'],
 
61
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
62
  nn_actor.nearest_neighbors()
63
 
64
+ # --- استدلال وتخزين النتائج ---
65
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
66
+ output_lines = []
67
+
68
  for batch in dataloader:
69
  with torch.no_grad():
70
  actor.greedy_search()
 
74
  nn_output = nn_actor(batch)
75
  nn_time = nn_output['total_time'].item()
76
 
77
+ improvement = (nn_time - total_time) / nn_time * 100
78
+
79
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
80
+
81
+ result_text = (
82
+ "\n===== INFERENCE RESULT =====\n"
83
+ f"Time: {timestamp}\n"
84
+ f"Actor Model Total Cost: {total_time:.4f}\n"
85
+ f"Nearest Neighbor Cost : {nn_time:.4f}\n"
86
+ f"Improvement over NN : {improvement:.2f}%\n"
87
+ )
88
+ print(result_text)
89
+ output_lines.append(result_text)
90
+
91
+ # حفظ النتائج إلى CSV
92
+ write_header = not os.path.exists(csv_results_file)
93
+ with open(csv_results_file, 'a', newline='') as csvfile:
94
+ writer = csv.writer(csvfile)
95
+ if write_header:
96
+ writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
97
+ writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])
98
+
99
+ # --- حفظ النتائج إلى ملف نصي ---
100
+ with open(txt_results_file, 'a') as f:
101
+ f.write("\n".join(output_lines))
102
+ f.write("\n=============================\n")
103
+
104
+ print(f"\n✅ Results saved to:\n- {txt_results_file}\n- {csv_results_file}")