a-ragab-h-m commited on
Commit
17b50e6
·
verified ·
1 Parent(s): 25c9425

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +69 -0
inference.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ import json
4
+ import os
5
+
6
+ from nets.model import Model
7
+ from Actor.actor import Actor
8
+ from dataloader import VRP_Dataset
9
+
10
+ # --- تحميل الإعدادات ---
11
+ with open('/data/params_saved.json', 'r') as f:
12
+ params = json.load(f)
13
+
14
+ # --- تعيين الجهاز ---
15
+ device = params['device']
16
+ dataset_path = params['dataset_path']
17
+ input_size = None # سيتم تحديده بعد تحميل البيانات
18
+
19
+ # --- تحميل نموذج مدرب ---
20
+ model_path = "/data/model_state_dict.pt"
21
+ if not os.path.exists(model_path):
22
+ raise FileNotFoundError(f"Model not found at {model_path}")
23
+
24
+ # --- إعداد بيانات عشوائية للاختبار ---
25
+ inference_dataset = VRP_Dataset(
26
+ size=1,
27
+ num_nodes=params['num_nodes'],
28
+ num_depots=params['num_depots'],
29
+ path=dataset_path,
30
+ device=device
31
+ )
32
+
33
+ input_size = inference_dataset.model_input_length()
34
+
35
+ # --- تحميل النموذج ---
36
+ model = Model(
37
+ input_size=input_size,
38
+ embedding_size=params["embedding_size"],
39
+ decoder_input_size=params["decoder_input_size"]
40
+ )
41
+ model.load_state_dict(torch.load(model_path, map_location=device))
42
+
43
+ # --- تهيئة الممثل (Actor) والـ NN Actor ---
44
+ actor = Actor(model=model,
45
+ num_movers=params['num_movers'],
46
+ num_neighbors_encoder=params['num_neighbors_encoder'],
47
+ num_neighbors_action=params['num_neighbors_action'],
48
+ device=device,
49
+ normalize=False)
50
+ actor.eval_mode()
51
+
52
+ nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
53
+ nn_actor.nearest_neighbors()
54
+
55
+ # --- تنفيذ الاستدلال على دفعة واحدة ---
56
+ dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
57
+ for batch in dataloader:
58
+ with torch.no_grad():
59
+ actor.greedy_search()
60
+ actor_output = actor(batch)
61
+ total_time = actor_output['total_time'].item()
62
+
63
+ nn_output = nn_actor(batch)
64
+ nn_time = nn_output['total_time'].item()
65
+
66
+ print("\n===== INFERENCE RESULT =====")
67
+ print(f"Actor Model Total Cost: {total_time:.4f}")
68
+ print(f"Nearest Neighbor Cost : {nn_time:.4f}")
69
+ print(f"Improvement over NN : {(nn_time - total_time) / nn_time * 100:.2f}%")