Update inference.py
Browse files- inference.py +46 -11
inference.py
CHANGED
@@ -2,26 +2,36 @@ import torch
|
|
2 |
from torch.utils.data import DataLoader
|
3 |
import json
|
4 |
import os
|
|
|
|
|
5 |
|
6 |
from nets.model import Model
|
7 |
from Actor.actor import Actor
|
8 |
from dataloader import VRP_Dataset
|
9 |
|
10 |
# --- تحميل الإعدادات ---
|
11 |
-
|
|
|
|
|
|
|
|
|
12 |
params = json.load(f)
|
13 |
|
14 |
# --- تعيين الجهاز ---
|
15 |
device = params['device']
|
16 |
dataset_path = params['dataset_path']
|
17 |
-
input_size = None # سيتم تحديده بعد تحميل البيانات
|
18 |
|
19 |
-
# ---
|
20 |
model_path = "/data/model_state_dict.pt"
|
21 |
if not os.path.exists(model_path):
|
22 |
raise FileNotFoundError(f"Model not found at {model_path}")
|
23 |
|
24 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
25 |
inference_dataset = VRP_Dataset(
|
26 |
size=1,
|
27 |
num_nodes=params['num_nodes'],
|
@@ -29,7 +39,6 @@ inference_dataset = VRP_Dataset(
|
|
29 |
path=dataset_path,
|
30 |
device=device
|
31 |
)
|
32 |
-
|
33 |
input_size = inference_dataset.model_input_length()
|
34 |
|
35 |
# --- تحميل النموذج ---
|
@@ -40,7 +49,7 @@ model = Model(
|
|
40 |
)
|
41 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
42 |
|
43 |
-
# --- تهيئة الممثل
|
44 |
actor = Actor(model=model,
|
45 |
num_movers=params['num_movers'],
|
46 |
num_neighbors_encoder=params['num_neighbors_encoder'],
|
@@ -52,8 +61,10 @@ actor.eval_mode()
|
|
52 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
53 |
nn_actor.nearest_neighbors()
|
54 |
|
55 |
-
# ---
|
56 |
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
|
|
|
|
|
57 |
for batch in dataloader:
|
58 |
with torch.no_grad():
|
59 |
actor.greedy_search()
|
@@ -63,7 +74,31 @@ for batch in dataloader:
|
|
63 |
nn_output = nn_actor(batch)
|
64 |
nn_time = nn_output['total_time'].item()
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from torch.utils.data import DataLoader
|
3 |
import json
|
4 |
import os
|
5 |
+
import csv
|
6 |
+
from datetime import datetime
|
7 |
|
8 |
from nets.model import Model
|
9 |
from Actor.actor import Actor
|
10 |
from dataloader import VRP_Dataset
|
11 |
|
12 |
# --- تحميل الإعدادات ---
|
13 |
+
params_path = '/data/params_saved.json'
|
14 |
+
if not os.path.exists(params_path):
|
15 |
+
raise FileNotFoundError(f"Settings file not found at {params_path}")
|
16 |
+
|
17 |
+
with open(params_path, 'r') as f:
|
18 |
params = json.load(f)
|
19 |
|
20 |
# --- تعيين الجهاز ---
|
21 |
device = params['device']
|
22 |
dataset_path = params['dataset_path']
|
|
|
23 |
|
24 |
+
# --- مسار النموذج المحفوظ ---
|
25 |
model_path = "/data/model_state_dict.pt"
|
26 |
if not os.path.exists(model_path):
|
27 |
raise FileNotFoundError(f"Model not found at {model_path}")
|
28 |
|
29 |
+
# --- تحضير مجلد النتائج ---
|
30 |
+
os.makedirs("/data", exist_ok=True)
|
31 |
+
txt_results_file = "/data/inference_results.txt"
|
32 |
+
csv_results_file = "/data/inference_results.csv"
|
33 |
+
|
34 |
+
# --- إعداد بيانات inference ---
|
35 |
inference_dataset = VRP_Dataset(
|
36 |
size=1,
|
37 |
num_nodes=params['num_nodes'],
|
|
|
39 |
path=dataset_path,
|
40 |
device=device
|
41 |
)
|
|
|
42 |
input_size = inference_dataset.model_input_length()
|
43 |
|
44 |
# --- تحميل النموذج ---
|
|
|
49 |
)
|
50 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
51 |
|
52 |
+
# --- تهيئة الممثل والـ NN Actor ---
|
53 |
actor = Actor(model=model,
|
54 |
num_movers=params['num_movers'],
|
55 |
num_neighbors_encoder=params['num_neighbors_encoder'],
|
|
|
61 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
62 |
nn_actor.nearest_neighbors()
|
63 |
|
64 |
+
# --- استدلال وتخزين النتائج ---
|
65 |
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
|
66 |
+
output_lines = []
|
67 |
+
|
68 |
for batch in dataloader:
|
69 |
with torch.no_grad():
|
70 |
actor.greedy_search()
|
|
|
74 |
nn_output = nn_actor(batch)
|
75 |
nn_time = nn_output['total_time'].item()
|
76 |
|
77 |
+
improvement = (nn_time - total_time) / nn_time * 100
|
78 |
+
|
79 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
80 |
+
|
81 |
+
result_text = (
|
82 |
+
"\n===== INFERENCE RESULT =====\n"
|
83 |
+
f"Time: {timestamp}\n"
|
84 |
+
f"Actor Model Total Cost: {total_time:.4f}\n"
|
85 |
+
f"Nearest Neighbor Cost : {nn_time:.4f}\n"
|
86 |
+
f"Improvement over NN : {improvement:.2f}%\n"
|
87 |
+
)
|
88 |
+
print(result_text)
|
89 |
+
output_lines.append(result_text)
|
90 |
+
|
91 |
+
# حفظ النتائج إلى CSV
|
92 |
+
write_header = not os.path.exists(csv_results_file)
|
93 |
+
with open(csv_results_file, 'a', newline='') as csvfile:
|
94 |
+
writer = csv.writer(csvfile)
|
95 |
+
if write_header:
|
96 |
+
writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
|
97 |
+
writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])
|
98 |
+
|
99 |
+
# --- حفظ النتائج إلى ملف نصي ---
|
100 |
+
with open(txt_results_file, 'a') as f:
|
101 |
+
f.write("\n".join(output_lines))
|
102 |
+
f.write("\n=============================\n")
|
103 |
+
|
104 |
+
print(f"\n✅ Results saved to:\n- {txt_results_file}\n- {csv_results_file}")
|