a-ragab-h-m committed on
Commit a370b55 · verified · 1 Parent(s): da396b7

Update inference.py

Files changed (1)
  1. inference.py +31 -38
inference.py CHANGED
@@ -1,5 +1,4 @@
  import torch
- from torch.utils.data import DataLoader, TensorDataset
  import json
  import os
  import csv
@@ -9,71 +8,67 @@ import numpy as np
 
  from nets.model import Model
  from Actor.actor import Actor
- from dataloader import VRP_Dataset
 
- # --- Specify the safe directory ---
  safe_data_dir = "/home/user/data"
  orders_file = os.path.join(safe_data_dir, "orders.csv")
-
- # --- Load the settings ---
  params_path = os.path.join(safe_data_dir, 'params_saved.json')
  if not os.path.exists(params_path):
      raise FileNotFoundError(f"Settings file not found at {params_path}")
 
  with open(params_path, 'r') as f:
      params = json.load(f)
 
  device = params['device']
 
  # --- Load the model ---
- model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
- if not os.path.exists(model_path):
-     raise FileNotFoundError(f"Model not found at {model_path}")
-
  model = Model(
-     input_size = 5,  # adjust according to the data representation
      embedding_size=params["embedding_size"],
      decoder_input_size=params["decoder_input_size"]
  )
  model.load_state_dict(torch.load(model_path, map_location=device))
  model.eval()
 
- # --- Set up the results files ---
- txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
- csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
- os.makedirs(safe_data_dir, exist_ok=True)
-
- # --- Read the CSV file ---
- if not os.path.exists(orders_file):
-     raise FileNotFoundError(f"orders.csv not found at {orders_file}")
-
  df = pd.read_csv(orders_file)
 
- # --- Prepare the data as a Tensor ---
  pickup_lng = df['lng'].to_numpy()
  pickup_lat = df['lat'].to_numpy()
  delivery_lng = df['delivery_gps_lng'].to_numpy()
  delivery_lat = df['delivery_gps_lat'].to_numpy()
 
- # Merge the data into one array (batch_size=1, N, features)
- coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)  # shape (N, 4)
  coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
 
- # --- Create an input batch for the model ---
  batch = [{"coords": coords_tensor}]
 
- # --- Set up the actors ---
- actor = Actor(model=model,
-               num_movers=params['num_movers'],
-               num_neighbors_encoder=params['num_neighbors_encoder'],
-               num_neighbors_action=params['num_neighbors_action'],
-               device=device,
-               normalize=False)
 
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
  nn_actor.nearest_neighbors()
 
- # --- Run the model ---
  with torch.no_grad():
      actor.greedy_search()
      actor_output = actor(batch)
@@ -85,7 +80,7 @@ with torch.no_grad():
      improvement = (nn_time - total_time) / nn_time * 100
      timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
-     # --- Preview part of the data used ---
      coords_preview = "\n".join([
          f"Order {i}: P=({lng1:.4f},{lat1:.4f}) → D=({lng2:.4f},{lat2:.4f})"
          for i, (lng1, lat1, lng2, lat2) in enumerate(coords[:5])
@@ -95,7 +90,7 @@ with torch.no_grad():
 
      input_summary = f"📌 Input Orders Preview:\n{coords_preview}"
 
-     # --- Full result for printing ---
      result_text = (
          "\n===== INFERENCE RESULT =====\n"
          f"Time: {timestamp}\n"
@@ -105,7 +100,7 @@ with torch.no_grad():
      )
      print(result_text)
 
-     # --- Gradio interface ---
      summary_text = (
          f"🕒 Time: {timestamp}\n"
          f"🚚 Actor Cost: {total_time:.4f} units\n"
@@ -113,8 +108,9 @@ with torch.no_grad():
          f"📈 Improvement: {improvement:.2f}%\n\n"
          f"{input_summary}"
      )
 
-     # --- Save to CSV ---
      write_header = not os.path.exists(csv_results_file)
      with open(csv_results_file, 'a', newline='') as csvfile:
          writer = csv.writer(csvfile)
@@ -126,6 +122,3 @@ with torch.no_grad():
      with open(txt_results_file, 'a') as f:
          f.write(result_text)
          f.write("\n=============================\n")
-
-     # --- Output for the UI
-     print(f"\n🔍 Summary for UI:\n{summary_text}")
@@ -1,5 +1,4 @@
  import torch
  import json
  import os
  import csv
@@ -9,71 +8,67 @@ import numpy as np
 
  from nets.model import Model
  from Actor.actor import Actor
 
+ # --- Set up the paths ---
  safe_data_dir = "/home/user/data"
  orders_file = os.path.join(safe_data_dir, "orders.csv")
  params_path = os.path.join(safe_data_dir, 'params_saved.json')
+ model_path = os.path.join(safe_data_dir, "model_state_dict.pt")
+ txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
+ csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
+
+ # --- Check that the files exist ---
  if not os.path.exists(params_path):
      raise FileNotFoundError(f"Settings file not found at {params_path}")
+ if not os.path.exists(model_path):
+     raise FileNotFoundError(f"Model not found at {model_path}")
+ if not os.path.exists(orders_file):
+     raise FileNotFoundError(f"orders.csv not found at {orders_file}")
 
+ # --- Load the settings ---
  with open(params_path, 'r') as f:
      params = json.load(f)
 
  device = params['device']
 
  # --- Load the model ---
  model = Model(
+     input_size=4,  # use only 4 features (lng1, lat1, lng2, lat2)
      embedding_size=params["embedding_size"],
      decoder_input_size=params["decoder_input_size"]
  )
  model.load_state_dict(torch.load(model_path, map_location=device))
  model.eval()
 
+ # --- Read the order data ---
  df = pd.read_csv(orders_file)
 
+ # --- Extract the coordinates ---
  pickup_lng = df['lng'].to_numpy()
  pickup_lat = df['lat'].to_numpy()
  delivery_lng = df['delivery_gps_lng'].to_numpy()
  delivery_lat = df['delivery_gps_lat'].to_numpy()
 
+ coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)
  coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
 
+ # --- Prepare the batch ---
  batch = [{"coords": coords_tensor}]
 
+ # --- Initialize the actor and the NN baseline ---
+ actor = Actor(
+     model=model,
+     num_movers=params['num_movers'],
+     num_neighbors_encoder=params['num_neighbors_encoder'],
+     num_neighbors_action=params['num_neighbors_action'],
+     device=device,
+     normalize=False
+ )
 
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
  nn_actor.nearest_neighbors()
 
+ # --- Run inference ---
  with torch.no_grad():
      actor.greedy_search()
      actor_output = actor(batch)
@@ -85,7 +80,7 @@ with torch.no_grad():
      improvement = (nn_time - total_time) / nn_time * 100
      timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
+     # --- Display a summary of the inputs ---
      coords_preview = "\n".join([
          f"Order {i}: P=({lng1:.4f},{lat1:.4f}) → D=({lng2:.4f},{lat2:.4f})"
          for i, (lng1, lat1, lng2, lat2) in enumerate(coords[:5])
@@ -95,7 +90,7 @@ with torch.no_grad():
 
      input_summary = f"📌 Input Orders Preview:\n{coords_preview}"
 
+     # --- Detailed result for printing ---
      result_text = (
          "\n===== INFERENCE RESULT =====\n"
          f"Time: {timestamp}\n"
@@ -105,7 +100,7 @@ with torch.no_grad():
      )
      print(result_text)
 
+     # --- Summary for the UI ---
      summary_text = (
          f"🕒 Time: {timestamp}\n"
          f"🚚 Actor Cost: {total_time:.4f} units\n"
@@ -113,8 +108,9 @@ with torch.no_grad():
          f"📈 Improvement: {improvement:.2f}%\n\n"
          f"{input_summary}"
      )
+     print(f"\n🔍 Summary for UI:\n{summary_text}")
 
+     # --- Save the results to CSV ---
      write_header = not os.path.exists(csv_results_file)
      with open(csv_results_file, 'a', newline='') as csvfile:
          writer = csv.writer(csvfile)
@@ -126,6 +122,3 @@ with torch.no_grad():
      with open(txt_results_file, 'a') as f:
          f.write(result_text)
          f.write("\n=============================\n")