a-ragab-h-m commited on
Commit
aa1a460
·
verified ·
1 Parent(s): 7bab133

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +6 -4
inference.py CHANGED
@@ -31,9 +31,9 @@ with open(params_path, 'r') as f:
31
 
32
  device = params['device']
33
 
34
- # --- تحميل النموذج ---
35
  model = Model(
36
- input_size = 5,
37
  embedding_size=params["embedding_size"],
38
  decoder_input_size=params["decoder_input_size"]
39
  )
@@ -52,8 +52,10 @@ delivery_lat = df['delivery_gps_lat'].to_numpy()
52
  coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)
53
  coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
54
 
55
- # --- تجهيز Batch ---
56
- batch = [{"coords": coords_tensor}]
 
 
57
 
58
  # --- تهيئة الممثل والـ NN ---
59
  actor = Actor(
 
31
 
32
  device = params['device']
33
 
34
+ # --- تحميل النموذج بعد ضبط input_size = 4 ---
35
  model = Model(
36
+ input_size = 4,
37
  embedding_size=params["embedding_size"],
38
  decoder_input_size=params["decoder_input_size"]
39
  )
 
52
  coords = np.stack([pickup_lng, pickup_lat, delivery_lng, delivery_lat], axis=1)
53
  coords_tensor = torch.tensor(coords, dtype=torch.float32).unsqueeze(0).to(device)
54
 
55
+ # --- تجهيز Batch كـ Tuple لتجنب unpacking error في Actor ---
56
+ graph_data = {"coords": coords_tensor}
57
+ fleet_data = {"dummy": torch.tensor([0])} # يمكن تعديل هذا لاحقاً عند الحاجة
58
+ batch = (graph_data, fleet_data)
59
 
60
  # --- تهيئة الممثل والـ NN ---
61
  actor = Actor(