Update inference.py
Browse files- inference.py +86 -78
inference.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import torch
|
2 |
-
from torch.utils.data import DataLoader
|
3 |
import json
|
4 |
import os
|
5 |
import csv
|
6 |
from datetime import datetime
|
|
|
7 |
import numpy as np
|
8 |
|
9 |
from nets.model import Model
|
@@ -12,6 +13,7 @@ from dataloader import VRP_Dataset
|
|
12 |
|
13 |
# --- تحديد المجلد الآمن ---
|
14 |
safe_data_dir = "/home/user/data"
|
|
|
15 |
|
16 |
# --- تحميل الإعدادات ---
|
17 |
params_path = os.path.join(safe_data_dir, 'params_saved.json')
|
@@ -21,41 +23,46 @@ if not os.path.exists(params_path):
|
|
21 |
with open(params_path, 'r') as f:
|
22 |
params = json.load(f)
|
23 |
|
24 |
-
# --- تعيين الجهاز ---
|
25 |
device = params['device']
|
26 |
-
dataset_path = params['dataset_path']
|
27 |
|
28 |
-
# ---
|
29 |
model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
|
30 |
if not os.path.exists(model_path):
|
31 |
raise FileNotFoundError(f"Model not found at {model_path}")
|
32 |
|
33 |
-
# --- تحضير مجلد النتائج ---
|
34 |
-
os.makedirs(safe_data_dir, exist_ok=True)
|
35 |
-
txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
|
36 |
-
csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
|
37 |
-
|
38 |
-
# --- إعداد بيانات inference ---
|
39 |
-
inference_dataset = VRP_Dataset(
|
40 |
-
dataset_size=1,
|
41 |
-
num_nodes=params['num_nodes'],
|
42 |
-
num_depots=params['num_depots'],
|
43 |
-
dataset_path=dataset_path,
|
44 |
-
device=device
|
45 |
-
)
|
46 |
-
|
47 |
-
input_size = inference_dataset.model_input_length()
|
48 |
-
|
49 |
-
# --- تحميل النموذج ---
|
50 |
model = Model(
|
51 |
-
input_size=
|
52 |
embedding_size=params["embedding_size"],
|
53 |
decoder_input_size=params["decoder_input_size"]
|
54 |
)
|
55 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
56 |
model.eval()
|
57 |
|
58 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
actor = Actor(model=model,
|
60 |
num_movers=params['num_movers'],
|
61 |
num_neighbors_encoder=params['num_neighbors_encoder'],
|
@@ -66,58 +73,59 @@ actor = Actor(model=model,
|
|
66 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
67 |
nn_actor.nearest_neighbors()
|
68 |
|
69 |
-
# ---
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
writer.writerow([
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
1 |
import torch
|
2 |
+
from torch.utils.data import DataLoader, TensorDataset
|
3 |
import json
|
4 |
import os
|
5 |
import csv
|
6 |
from datetime import datetime
|
7 |
+
import pandas as pd
|
8 |
import numpy as np
|
9 |
|
10 |
from nets.model import Model
|
|
|
13 |
|
14 |
# --- تحديد المجلد الآمن ---
|
15 |
safe_data_dir = "/home/user/data"
|
16 |
+
orders_file = os.path.join(safe_data_dir, "orders.csv")
|
17 |
|
18 |
# --- تحميل الإعدادات ---
|
19 |
params_path = os.path.join(safe_data_dir, 'params_saved.json')
|
|
|
23 |
with open(params_path, 'r') as f:
|
24 |
params = json.load(f)
|
25 |
|
|
|
26 |
device = params['device']
|
|
|
27 |
|
28 |
+
# --- تحميل النموذج ---
|
29 |
model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
|
30 |
if not os.path.exists(model_path):
|
31 |
raise FileNotFoundError(f"Model not found at {model_path}")
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
model = Model(
|
34 |
+
input_size=params["num_nodes"] * 4, # تعديل حسب تمثيل البيانات
|
35 |
embedding_size=params["embedding_size"],
|
36 |
decoder_input_size=params["decoder_input_size"]
|
37 |
)
|
38 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
39 |
model.eval()
|
40 |
|
41 |
+
# --- إعداد ملف النتائج ---
|
42 |
+
txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
|
43 |
+
csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
|
44 |
+
os.makedirs(safe_data_dir, exist_ok=True)
|
45 |
+
|
46 |
+
# --- قراءة ملف CSV ---
|
47 |
+
if not os.path.exists(orders_file):
|
48 |
+
raise FileNotFoundError(f"orders.csv not found at {orders_file}")
|
49 |
+
|
50 |
+
df = pd.read_csv(orders_file)
|
51 |
+
|
52 |
+
# --- تجهيز البيانات كـ Tensor ---
|
53 |
+
pickup_lng = df['lng'].to_numpy()
|
54 |
+
pickup_lat = df['lat'].to_numpy()
|
55 |
+
delivery_lng = df['delivery_gps_lng'].to_numpy()
|
56 |
+
delivery_lat = df['delivery_gps_lat'].to_numpy()
|
57 |
+
|
58 |
+
# دمج البيانات في مصفوفة واحدة (batch_size=1, N, features)
|
59 |
+
coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1) # shape (N, 4)
|
60 |
+
coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
|
61 |
+
|
62 |
+
# --- إنشاء Batch مدخل للنموذج ---
|
63 |
+
batch = [{"coords": coords_tensor}]
|
64 |
+
|
65 |
+
# --- إعداد الممثلين ---
|
66 |
actor = Actor(model=model,
|
67 |
num_movers=params['num_movers'],
|
68 |
num_neighbors_encoder=params['num_neighbors_encoder'],
|
|
|
73 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
74 |
nn_actor.nearest_neighbors()
|
75 |
|
76 |
+
# --- تشغيل النموذج ---
|
77 |
+
with torch.no_grad():
|
78 |
+
actor.greedy_search()
|
79 |
+
actor_output = actor(batch)
|
80 |
+
total_time = actor_output['total_time'].item()
|
81 |
+
|
82 |
+
nn_output = nn_actor(batch)
|
83 |
+
nn_time = nn_output['total_time'].item()
|
84 |
+
|
85 |
+
improvement = (nn_time - total_time) / nn_time * 100
|
86 |
+
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
87 |
+
|
88 |
+
# --- عرض جزء من البيانات المستخدمة ---
|
89 |
+
coords_preview = "\n".join([
|
90 |
+
f"Order {i}: P=({lng1:.4f},{lat1:.4f}) → D=({lng2:.4f},{lat2:.4f})"
|
91 |
+
for i, (lng1, lat1, lng2, lat2) in enumerate(coords[:5])
|
92 |
+
])
|
93 |
+
if coords.shape[0] > 5:
|
94 |
+
coords_preview += f"\n... (showing 5 of {coords.shape[0]} orders)"
|
95 |
+
|
96 |
+
input_summary = f"📌 Input Orders Preview:\n{coords_preview}"
|
97 |
+
|
98 |
+
# --- نتيجة كاملة للطباعة ---
|
99 |
+
result_text = (
|
100 |
+
"\n===== INFERENCE RESULT =====\n"
|
101 |
+
f"Time: {timestamp}\n"
|
102 |
+
f"Actor Model Total Cost: {total_time:.4f} units\n"
|
103 |
+
f"Nearest Neighbor Cost : {nn_time:.4f} units\n"
|
104 |
+
f"Improvement over NN : {improvement:.2f}%\n"
|
105 |
+
)
|
106 |
+
print(result_text)
|
107 |
+
|
108 |
+
# --- واجهة Gradio ---
|
109 |
+
summary_text = (
|
110 |
+
f"🕒 Time: {timestamp}\n"
|
111 |
+
f"🚚 Actor Cost: {total_time:.4f} units\n"
|
112 |
+
f"📍 NN Cost: {nn_time:.4f} units\n"
|
113 |
+
f"📈 Improvement: {improvement:.2f}%\n\n"
|
114 |
+
f"{input_summary}"
|
115 |
+
)
|
116 |
+
|
117 |
+
# --- حفظ إلى CSV ---
|
118 |
+
write_header = not os.path.exists(csv_results_file)
|
119 |
+
with open(csv_results_file, 'a', newline='') as csvfile:
|
120 |
+
writer = csv.writer(csvfile)
|
121 |
+
if write_header:
|
122 |
+
writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
|
123 |
+
writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])
|
124 |
+
|
125 |
+
# --- حفظ نصي ---
|
126 |
+
with open(txt_results_file, 'a') as f:
|
127 |
+
f.write(result_text)
|
128 |
+
f.write("\n=============================\n")
|
129 |
+
|
130 |
+
# --- إخراج للواجهة
|
131 |
+
print(f"\n🔍 Summary for UI:\n{summary_text}")
|