a-ragab-h-m commited on
Commit
19355fa
ยท
verified ยท
1 Parent(s): 7724504

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +26 -9
inference.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import os
5
  import csv
6
  from datetime import datetime
 
7
 
8
  from nets.model import Model
9
  from Actor.actor import Actor
@@ -52,7 +53,7 @@ model = Model(
52
  decoder_input_size=params["decoder_input_size"]
53
  )
54
  model.load_state_dict(torch.load(model_path, map_location=device))
55
- model.eval() # โ† ู‡ุฐุง ู‡ูˆ ุงู„ุชุนุฏูŠู„ ุงู„ู…ู‡ู…
56
 
57
  # --- ุชู‡ูŠุฆุฉ ุงู„ู…ู…ุซู„ ูˆุงู„ู€ NN Actor ---
58
  actor = Actor(model=model,
@@ -62,7 +63,6 @@ actor = Actor(model=model,
62
  device=device,
63
  normalize=False)
64
 
65
-
66
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
67
  nn_actor.nearest_neighbors()
68
 
@@ -70,6 +70,7 @@ nn_actor.nearest_neighbors()
70
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
71
  output_lines = []
72
  summary_text = ""
 
73
 
74
  for batch in dataloader:
75
  with torch.no_grad():
@@ -83,23 +84,39 @@ for batch in dataloader:
83
  improvement = (nn_time - total_time) / nn_time * 100
84
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
85
 
 
 
 
 
 
 
 
 
 
 
 
86
  result_text = (
 
 
87
  "\n===== INFERENCE RESULT =====\n"
88
  f"Time: {timestamp}\n"
89
- f"Actor Model Total Cost: {total_time:.4f}\n"
90
- f"Nearest Neighbor Cost : {nn_time:.4f}\n"
91
  f"Improvement over NN : {improvement:.2f}%\n"
92
  )
93
  print(result_text)
94
  output_lines.append(result_text)
95
 
 
96
  summary_text = (
97
- f"Inference Time: {timestamp}\n"
98
- f"Actor Cost = {total_time:.4f} | NN Cost = {nn_time:.4f} "
99
- f"| Improvement = {improvement:.2f}%"
 
 
100
  )
101
 
102
- # ุญูุธ ุงู„ู†ุชุงุฆุฌ ุฅู„ู‰ CSV
103
  write_header = not os.path.exists(csv_results_file)
104
  with open(csv_results_file, 'a', newline='') as csvfile:
105
  writer = csv.writer(csvfile)
@@ -112,5 +129,5 @@ with open(txt_results_file, 'a') as f:
112
  f.write("\n".join(output_lines))
113
  f.write("\n=============================\n")
114
 
115
- # ู‡ุฐุง ุงู„ุณุทุฑ ู‡ูˆ ุงู„ุฌุฏูŠุฏ ู„ุนุฑุถ ุงู„ู†ุชุงุฆุฌ ููŠ ูˆุงุฌู‡ุฉ Gradio
116
  print(f"\n๐Ÿ” Summary for UI:\n{summary_text}")
 
4
  import os
5
  import csv
6
  from datetime import datetime
7
+ import numpy as np
8
 
9
  from nets.model import Model
10
  from Actor.actor import Actor
 
53
  decoder_input_size=params["decoder_input_size"]
54
  )
55
  model.load_state_dict(torch.load(model_path, map_location=device))
56
+ model.eval()
57
 
58
  # --- ุชู‡ูŠุฆุฉ ุงู„ู…ู…ุซู„ ูˆุงู„ู€ NN Actor ---
59
  actor = Actor(model=model,
 
63
  device=device,
64
  normalize=False)
65
 
 
66
  nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
67
  nn_actor.nearest_neighbors()
68
 
 
70
  dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
71
  output_lines = []
72
  summary_text = ""
73
+ input_summary = ""
74
 
75
  for batch in dataloader:
76
  with torch.no_grad():
 
84
  improvement = (nn_time - total_time) / nn_time * 100
85
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
86
 
87
+ # --- ู…ู„ุฎุต ุงู„ุจูŠุงู†ุงุช ---
88
+ coords = batch['coords'][0].cpu().numpy() # shape: (N, 2)
89
+ coords_preview = "\n".join(
90
+ [f"Node {i}: x={x:.3f}, y={y:.3f}" for i, (x, y) in enumerate(coords[:5])]
91
+ )
92
+ if coords.shape[0] > 5:
93
+ coords_preview += f"\n... (showing 5 of {coords.shape[0]} nodes)"
94
+
95
+ input_summary = f"๐Ÿ“Œ **Input Coordinates Preview:**\n{coords_preview}"
96
+
97
+ # --- ู†ุต ูƒุงู…ู„ ู„ู„ุนุฑุถ ูˆุงู„ุทุจุงุนุฉ ---
98
  result_text = (
99
+ "\n===== INPUT SUMMARY =====\n"
100
+ f"{coords_preview}\n"
101
  "\n===== INFERENCE RESULT =====\n"
102
  f"Time: {timestamp}\n"
103
+ f"Actor Model Total Cost: {total_time:.4f} units\n"
104
+ f"Nearest Neighbor Cost : {nn_time:.4f} units\n"
105
  f"Improvement over NN : {improvement:.2f}%\n"
106
  )
107
  print(result_text)
108
  output_lines.append(result_text)
109
 
110
+ # --- ู„ู„ูˆุงุฌู‡ุฉ Gradio ---
111
  summary_text = (
112
+ f"๐Ÿ•’ Time: {timestamp}\n"
113
+ f"๐Ÿšš Actor Cost: {total_time:.4f} units\n"
114
+ f"๐Ÿ“ NN Cost: {nn_time:.4f} units\n"
115
+ f"๐Ÿ“ˆ Improvement: {improvement:.2f}%\n\n"
116
+ f"{input_summary}"
117
  )
118
 
119
+ # --- CSV ุญูุธ ---
120
  write_header = not os.path.exists(csv_results_file)
121
  with open(csv_results_file, 'a', newline='') as csvfile:
122
  writer = csv.writer(csvfile)
 
129
  f.write("\n".join(output_lines))
130
  f.write("\n=============================\n")
131
 
132
+ # --- ุทุจุงุนุฉ ู„ู„ู€ UI
133
  print(f"\n๐Ÿ” Summary for UI:\n{summary_text}")