a-ragab-h-m commited on
Commit
9efbcc9
·
verified ·
1 Parent(s): 7498d21

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +9 -3
inference.py CHANGED
@@ -35,7 +35,6 @@ txt_results_file = os.path.join(safe_data_dir, "inference_results.txt")
35
  csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
36
 
37
  # --- إعداد بيانات inference ---
38
-
39
  inference_dataset = VRP_Dataset(
40
  dataset_size=1,
41
  num_nodes=params['num_nodes'],
@@ -69,6 +68,7 @@ nn_actor.nearest_neighbors()
69
  # --- استدلال وتخزين النتائج ---
70
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
71
  output_lines = []
 
72
 
73
  for batch in dataloader:
74
  with torch.no_grad():
@@ -80,7 +80,6 @@ for batch in dataloader:
80
  nn_time = nn_output['total_time'].item()
81
 
82
  improvement = (nn_time - total_time) / nn_time * 100
83
-
84
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
85
 
86
  result_text = (
@@ -93,6 +92,12 @@ for batch in dataloader:
93
  print(result_text)
94
  output_lines.append(result_text)
95
 
 
 
 
 
 
 
96
  # حفظ النتائج إلى CSV
97
  write_header = not os.path.exists(csv_results_file)
98
  with open(csv_results_file, 'a', newline='') as csvfile:
@@ -106,4 +111,5 @@ with open(txt_results_file, 'a') as f:
106
  f.write("\n".join(output_lines))
107
  f.write("\n=============================\n")
108
 
109
- print(f"\n✅ Results saved to:\n- {txt_results_file}\n- {csv_results_file}")
 
 
35
  csv_results_file = os.path.join(safe_data_dir, "inference_results.csv")
36
 
37
  # --- إعداد بيانات inference ---
 
38
  inference_dataset = VRP_Dataset(
39
  dataset_size=1,
40
  num_nodes=params['num_nodes'],
 
68
  # --- استدلال وتخزين النتائج ---
69
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
70
  output_lines = []
71
+ summary_text = ""
72
 
73
  for batch in dataloader:
74
  with torch.no_grad():
 
80
  nn_time = nn_output['total_time'].item()
81
 
82
  improvement = (nn_time - total_time) / nn_time * 100
 
83
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
84
 
85
  result_text = (
 
92
  print(result_text)
93
  output_lines.append(result_text)
94
 
95
+ summary_text = (
96
+ f"Inference Time: {timestamp}\n"
97
+ f"Actor Cost = {total_time:.4f} | NN Cost = {nn_time:.4f} "
98
+ f"| Improvement = {improvement:.2f}%"
99
+ )
100
+
101
  # حفظ النتائج إلى CSV
102
  write_header = not os.path.exists(csv_results_file)
103
  with open(csv_results_file, 'a', newline='') as csvfile:
 
111
  f.write("\n".join(output_lines))
112
  f.write("\n=============================\n")
113
 
114
+ # هذا السطر هو الجديد لعرض النتائج في واجهة Gradio
115
+ print(f"\n🔍 Summary for UI:\n{summary_text}")