Update inference.py
Browse files- inference.py +6 -4
inference.py
CHANGED
@@ -31,9 +31,9 @@ with open(params_path, 'r') as f:
|
|
31 |
|
32 |
device = params['device']
|
33 |
|
34 |
-
# --- تحميل النموذج ---
|
35 |
model = Model(
|
36 |
-
input_size =
|
37 |
embedding_size=params["embedding_size"],
|
38 |
decoder_input_size=params["decoder_input_size"]
|
39 |
)
|
@@ -52,8 +52,10 @@ delivery_lat = df['delivery_gps_lat'].to_numpy()
|
|
52 |
coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)
|
53 |
coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
|
54 |
|
55 |
-
# --- تجهيز Batch ---
|
56 |
-
|
|
|
|
|
57 |
|
58 |
# --- تهيئة الممثل والـ NN ---
|
59 |
actor = Actor(
|
|
|
31 |
|
32 |
device = params['device']
|
33 |
|
34 |
+
# --- تحميل النموذج بعد ضبط input_size = 4 ---
|
35 |
model = Model(
|
36 |
+
input_size = 4,
|
37 |
embedding_size=params["embedding_size"],
|
38 |
decoder_input_size=params["decoder_input_size"]
|
39 |
)
|
|
|
52 |
coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)
|
53 |
coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
|
54 |
|
55 |
+
# --- تجهيز Batch كـ Tuple لتجنب unpacking error في Actor ---
|
56 |
+
graph_data = {"coords": coords_tensor}
|
57 |
+
fleet_data = {"dummy": torch.tensor([0])} # يمكن تعديل هذا لاحقاً عند الحاجة
|
58 |
+
batch = (graph_data, fleet_data)
|
59 |
|
60 |
# --- تهيئة الممثل والـ NN ---
|
61 |
actor = Actor(
|