# vrp-shanghai-transformer / inference.py
# Author: a-ragab-h-m — last change "Update inference.py" (commit 7663567, verified)
# NOTE: this file was copied from a Hugging Face web view; the
# "raw / history / blame / 3.49 kB" lines were page chrome, not code.
import torch
from torch.utils.data import DataLoader
import json
import os
import csv
from datetime import datetime
from nets.model import Model
from Actor.actor import Actor
from dataloader import VRP_Dataset
# --- Load settings ---
# Training persists its hyperparameters to JSON; inference must reuse the
# exact same configuration so the model architecture matches the checkpoint.
params_path = '/data/params_saved.json'
if not os.path.exists(params_path):
    raise FileNotFoundError(f"Settings file not found at {params_path}")
with open(params_path, 'r') as f:
    params = json.load(f)

# --- Set the device ---
# NOTE(review): assumes params['device'] is a valid torch device string
# (e.g. "cpu" or "cuda") — confirm against the training script.
device = params['device']
dataset_path = params['dataset_path']
# --- Path of the saved model ---
model_path = "/data/model_state_dict.pt"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model not found at {model_path}")

# --- Prepare the results directory ---
os.makedirs("/data", exist_ok=True)
txt_results_file = "/data/inference_results.txt"
csv_results_file = "/data/inference_results.csv"
# --- Prepare the inference data ---
# A single instance is enough for a demo run; node and depot counts must
# match the configuration the model was trained with.
inference_dataset = VRP_Dataset(
    size=1,
    num_nodes=params['num_nodes'],
    num_depots=params['num_depots'],
    path=dataset_path,
    device=device,
)

# Encoder input width is derived from the dataset's feature layout.
input_size = inference_dataset.model_input_length()
# --- Load the model ---
model = Model(
    input_size=input_size,
    embedding_size=params["embedding_size"],
    decoder_input_size=params["decoder_input_size"],
)
# map_location controls where the checkpoint tensors are deserialized;
# load_state_dict then copies them into the freshly built model in place.
# NOTE(review): the model is never moved with model.to(device) here —
# presumably Actor handles device placement; confirm.
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
# --- Initialize the learned actor and the NN baseline actor ---
actor = Actor(
    model=model,
    num_movers=params['num_movers'],
    num_neighbors_encoder=params['num_neighbors_encoder'],
    num_neighbors_action=params['num_neighbors_action'],
    device=device,
    normalize=False,
)
actor.eval_mode()

# Baseline: a model-free actor configured to follow the
# nearest-neighbor heuristic for comparison.
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
nn_actor.nearest_neighbors()
# --- Run inference and store the results ---
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)

output_lines = []
for batch in dataloader:
    with torch.no_grad():
        # Learned policy, decoded greedily.
        actor.greedy_search()
        actor_output = actor(batch)
        total_time = actor_output['total_time'].item()

        # Nearest-neighbor baseline on the same batch.
        nn_output = nn_actor(batch)
        nn_time = nn_output['total_time'].item()

    # Relative improvement of the learned policy over the baseline, in percent.
    # NOTE(review): assumes nn_time > 0; a zero baseline cost would divide by zero.
    improvement = (nn_time - total_time) / nn_time * 100
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    result_text = (
        "\n===== INFERENCE RESULT =====\n"
        f"Time: {timestamp}\n"
        f"Actor Model Total Cost: {total_time:.4f}\n"
        f"Nearest Neighbor Cost : {nn_time:.4f}\n"
        f"Improvement over NN : {improvement:.2f}%\n"
    )
    print(result_text)
    output_lines.append(result_text)

    # Save the results to CSV (write the header only when the file is new).
    write_header = not os.path.exists(csv_results_file)
    with open(csv_results_file, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            writer.writerow(["Timestamp", "Actor Cost", "NN Cost", "Improvement (%)"])
        writer.writerow([timestamp, f"{total_time:.4f}", f"{nn_time:.4f}", f"{improvement:.2f}"])

# --- Save the results to a text file ---
# All per-batch summaries accumulated above are appended in one write.
with open(txt_results_file, 'a') as f:
    f.write("\n".join(output_lines))
    f.write("\n=============================\n")

print(f"\n✅ Results saved to:\n- {txt_results_file}\n- {csv_results_file}")