a-ragab-h-m commited on
Commit
88ea10b
·
verified ·
1 Parent(s): bc77cc6

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +86 -78
inference.py CHANGED
@@ -1,9 +1,10 @@
1
  import torch
2
- from torch.utils.data import DataLoader
3
  import json
4
  import os
5
  import csv
6
  from datetime import datetime
 
7
  import numpy as np
8
 
9
  from nets.model import Model
@@ -12,6 +13,7 @@ from dataloader import VRP_Dataset
12
 
13
  # --- تحديد المجلد الآمن ---
14
  safe_data_dir = "/home/user/data"
 
15
 
16
  # --- تحميل الإعدادات ---
17
  params_path = os.path.join(safe_data_dir, 'params_saved.json')
@@ -21,41 +23,46 @@ if not os.path.exists(params_path):
21
  with open(params_path, 'r') as f:
22
  params = json.load(f)
23
 
24
- # --- تعيين الجهاز ---
25
  device = params['device']
26
- dataset_path = params['dataset_path']
27
 
28
- # --- مسار النموذج المحفوظ ---
29
  model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
30
  if not os.path.exists(model_path):
31
  raise FileNotFoundError(f"Model not found at {model_path}")
32
 
33
- # --- تحضير مجلد النتائج ---
34
- os.makedirs(safe_data_dir, exist_ok=True)
35
- txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
36
- csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
37
-
38
- # --- إعداد بيانات inference ---
39
- inference_dataset = VRP_Dataset(
40
- dataset_size=1,
41
- num_nodes=params['num_nodes'],
42
- num_depots=params['num_depots'],
43
- dataset_path=dataset_path,
44
- device=device
45
- )
46
-
47
- input_size = inference_dataset.model_input_length()
48
-
49
- # --- تحميل النموذج ---
50
  model = Model(
51
- input_size=input_size,
52
  embedding_size=params["embedding_size"],
53
  decoder_input_size=params["decoder_input_size"]
54
  )
55
  model.load_state_dict(torch.load(model_path, map_location=device))
56
  model.eval()
57
 
58
- # --- تهيئة الممثل والـ NN Actor ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  actor = Actor(model=model,
60
  num_movers=params['num_movers'],
61
  num_neighbors_encoder=params['num_neighbors_encoder'],
@@ -66,58 +73,59 @@ actor = Actor(model=model,
66
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
67
  nn_actor.nearest_neighbors()
68
 
69
- # --- استدلال وتخزين النتائج ---
70
- dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
71
- output_lines = []
72
- summary_text = ""
73
- input_summary = ""
74
-
75
- for batch in dataloader:
76
- with torch.no_grad():
77
- data = batch[0] # ← تعديل هام
78
- actor.greedy_search()
79
- actor_output = actor(batch)
80
- total_time = actor_output['total_time'].item()
81
-
82
- nn_output = nn_actor(batch)
83
- nn_time = nn_output['total_time'].item()
84
-
85
- improvement = (nn_time - total_time) / nn_time * 100
86
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
87
-
88
-
89
- # --- نص كامل للعرض والطباعة ---
90
- result_text = (
91
- "\n===== INFERENCE RESULT =====\n"
92
- f"Time: {timestamp}\n"
93
- f"Actor Model Total Cost: {total_time:.4f} units\n"
94
- f"Nearest Neighbor Cost : {nn_time:.4f} units\n"
95
- f"Improvement over NN : {improvement:.2f}%\n"
96
- )
97
- print(result_text)
98
- output_lines.append(result_text)
99
-
100
- # --- للواجهة Gradio ---
101
- summary_text = (
102
- f"🕒 Time: {timestamp}\n"
103
- f"🚚 Actor Cost: {total_time:.4f} units\n"
104
- f"📍 NN Cost: {nn_time:.4f} units\n"
105
- f"📈 Improvement: {improvement:.2f}%\n\n"
106
- f"{input_summary}"
107
- )
108
-
109
- # --- CSV حفظ ---
110
- write_header = not os.path.exists(csv_results_file)
111
- with open(csv_results_file, 'a', newline='') as csvfile:
112
- writer = csv.writer(csvfile)
113
- if write_header:
114
- writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
115
- writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])
116
-
117
- # --- حفظ النتائج إلى ملف نصي ---
118
- with open(txt_results_file, 'a') as f:
119
- f.write("\n".join(output_lines))
120
- f.write("\n=============================\n")
121
-
122
- # --- طباعة للـ UI
123
- print(f"\n🔍 Summary for UI:\n{summary_text}")
 
 
1
  import torch
2
+ from torch.utils.data import DataLoader, TensorDataset
3
  import json
4
  import os
5
  import csv
6
  from datetime import datetime
7
+ import pandas as pd
8
  import numpy as np
9
 
10
  from nets.model import Model
 
13
 
14
  # --- تحديد المجلد الآمن ---
15
  safe_data_dir = "/home/user/data"
16
+ orders_file = os.path.join(safe_data_dir, "orders.csv")
17
 
18
  # --- تحميل الإعدادات ---
19
  params_path = os.path.join(safe_data_dir, 'params_saved.json')
 
23
  with open(params_path, 'r') as f:
24
  params = json.load(f)
25
 
 
26
  device = params['device']
 
27
 
28
+ # --- تحميل النموذج ---
29
  model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
30
  if not os.path.exists(model_path):
31
  raise FileNotFoundError(f"Model not found at {model_path}")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  model = Model(
34
+ input_size=params["num_nodes"] * 4, # تعديل حسب تمثيل البيانات
35
  embedding_size=params["embedding_size"],
36
  decoder_input_size=params["decoder_input_size"]
37
  )
38
  model.load_state_dict(torch.load(model_path, map_location=device))
39
  model.eval()
40
 
41
+ # --- إعداد ملف النتائج ---
42
+ txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
43
+ csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
44
+ os.makedirs(safe_data_dir, exist_ok=True)
45
+
46
+ # --- قراءة ملف CSV ---
47
+ if not os.path.exists(orders_file):
48
+ raise FileNotFoundError(f"orders.csv not found at {orders_file}")
49
+
50
+ df = pd.read_csv(orders_file)
51
+
52
+ # --- تجهيز البيانات كـ Tensor ---
53
+ pickup_lng = df['lng'].to_numpy()
54
+ pickup_lat = df['lat'].to_numpy()
55
+ delivery_lng = df['delivery_gps_lng'].to_numpy()
56
+ delivery_lat = df['delivery_gps_lat'].to_numpy()
57
+
58
+ # دمج البيانات في مصفوفة واحدة (batch_size=1, N, features)
59
+ coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1) # shape (N, 4)
60
+ coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
61
+
62
+ # --- إنشاء Batch مدخل للنموذج ---
63
+ batch = [{"coords": coords_tensor}]
64
+
65
+ # --- إعداد الممثلين ---
66
  actor = Actor(model=model,
67
  num_movers=params['num_movers'],
68
  num_neighbors_encoder=params['num_neighbors_encoder'],
 
73
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
74
  nn_actor.nearest_neighbors()
75
 
76
+ # --- تشغيل النموذج ---
77
+ with torch.no_grad():
78
+ actor.greedy_search()
79
+ actor_output = actor(batch)
80
+ total_time = actor_output['total_time'].item()
81
+
82
+ nn_output = nn_actor(batch)
83
+ nn_time = nn_output['total_time'].item()
84
+
85
+ improvement = (nn_time - total_time) / nn_time * 100
86
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
87
+
88
+ # --- عرض جزء من البيانات المستخدمة ---
89
+ coords_preview = "\n".join([
90
+ f"Order {i}: P=({lng1:.4f},{lat1:.4f}) → D=({lng2:.4f},{lat2:.4f})"
91
+ for i, (lng1, lat1, lng2, lat2) in enumerate(coords[:5])
92
+ ])
93
+ if coords.shape[0] > 5:
94
+ coords_preview += f"\n... (showing 5 of {coords.shape[0]} orders)"
95
+
96
+ input_summary = f"📌 Input Orders Preview:\n{coords_preview}"
97
+
98
+ # --- نتيجة كاملة للطباعة ---
99
+ result_text = (
100
+ "\n===== INFERENCE RESULT =====\n"
101
+ f"Time: {timestamp}\n"
102
+ f"Actor Model Total Cost: {total_time:.4f} units\n"
103
+ f"Nearest Neighbor Cost : {nn_time:.4f} units\n"
104
+ f"Improvement over NN : {improvement:.2f}%\n"
105
+ )
106
+ print(result_text)
107
+
108
+ # --- واجهة Gradio ---
109
+ summary_text = (
110
+ f"🕒 Time: {timestamp}\n"
111
+ f"🚚 Actor Cost: {total_time:.4f} units\n"
112
+ f"📍 NN Cost: {nn_time:.4f} units\n"
113
+ f"📈 Improvement: {improvement:.2f}%\n\n"
114
+ f"{input_summary}"
115
+ )
116
+
117
+ # --- حفظ إلى CSV ---
118
+ write_header = not os.path.exists(csv_results_file)
119
+ with open(csv_results_file, 'a', newline='') as csvfile:
120
+ writer = csv.writer(csvfile)
121
+ if write_header:
122
+ writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
123
+ writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])
124
+
125
+ # --- حفظ نصي ---
126
+ with open(txt_results_file, 'a') as f:
127
+ f.write(result_text)
128
+ f.write("\n=============================\n")
129
+
130
+ # --- إخراج للواجهة
131
+ print(f"\n🔍 Summary for UI:\n{summary_text}")