Update inference.py
Browse files- inference.py +26 -9
inference.py
CHANGED
@@ -4,6 +4,7 @@ import json
|
|
4 |
import os
|
5 |
import csv
|
6 |
from datetime import datetime
|
|
|
7 |
|
8 |
from nets.model import Model
|
9 |
from Actor.actor import Actor
|
@@ -52,7 +53,7 @@ model = Model(
|
|
52 |
decoder_input_size=params["decoder_input_size"]
|
53 |
)
|
54 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
55 |
-
model.eval()
|
56 |
|
57 |
# --- ุชููุฆุฉ ุงูู
ู
ุซู ูุงูู NN Actor ---
|
58 |
actor = Actor(model=model,
|
@@ -62,7 +63,6 @@ actor = Actor(model=model,
|
|
62 |
device=device,
|
63 |
normalize=False)
|
64 |
|
65 |
-
|
66 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
67 |
nn_actor.nearest_neighbors()
|
68 |
|
@@ -70,6 +70,7 @@ nn_actor.nearest_neighbors()
|
|
70 |
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
|
71 |
output_lines = []
|
72 |
summary_text = ""
|
|
|
73 |
|
74 |
for batch in dataloader:
|
75 |
with torch.no_grad():
|
@@ -83,23 +84,39 @@ for batch in dataloader:
|
|
83 |
improvement = (nn_time - total_time) / nn_time * 100
|
84 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
result_text = (
|
|
|
|
|
87 |
"\n===== INFERENCE RESULT =====\n"
|
88 |
f"Time: {timestamp}\n"
|
89 |
-
f"Actor Model Total Cost: {total_time:.4f}\n"
|
90 |
-
f"Nearest Neighbor Cost : {nn_time:.4f}\n"
|
91 |
f"Improvement over NN : {improvement:.2f}%\n"
|
92 |
)
|
93 |
print(result_text)
|
94 |
output_lines.append(result_text)
|
95 |
|
|
|
96 |
summary_text = (
|
97 |
-
f"
|
98 |
-
f"Actor Cost
|
99 |
-
f"
|
|
|
|
|
100 |
)
|
101 |
|
102 |
-
#
|
103 |
write_header = not os.path.exists(csv_results_file)
|
104 |
with open(csv_results_file, 'a', newline='') as csvfile:
|
105 |
writer = csv.writer(csvfile)
|
@@ -112,5 +129,5 @@ with open(txt_results_file, 'a') as f:
|
|
112 |
f.write("\n".join(output_lines))
|
113 |
f.write("\n=============================\n")
|
114 |
|
115 |
-
#
|
116 |
print(f"\n๐ Summary for UI:\n{summary_text}")
|
|
|
4 |
import os
|
5 |
import csv
|
6 |
from datetime import datetime
|
7 |
+
import numpy as np
|
8 |
|
9 |
from nets.model import Model
|
10 |
from Actor.actor import Actor
|
|
|
53 |
decoder_input_size=params["decoder_input_size"]
|
54 |
)
|
55 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
56 |
+
model.eval()
|
57 |
|
58 |
# --- ุชููุฆุฉ ุงูู
ู
ุซู ูุงูู NN Actor ---
|
59 |
actor = Actor(model=model,
|
|
|
63 |
device=device,
|
64 |
normalize=False)
|
65 |
|
|
|
66 |
nn_actor = Actor(model=None, num_movers=1, num_neighbors_action=1, device=device)
|
67 |
nn_actor.nearest_neighbors()
|
68 |
|
|
|
70 |
dataloader = DataLoader(inference_dataset, batch_size=1, collate_fn=inference_dataset.collate)
|
71 |
output_lines = []
|
72 |
summary_text = ""
|
73 |
+
input_summary = ""
|
74 |
|
75 |
for batch in dataloader:
|
76 |
with torch.no_grad():
|
|
|
84 |
improvement = (nn_time - total_time) / nn_time * 100
|
85 |
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
86 |
|
87 |
+
# --- ู
ูุฎุต ุงูุจูุงูุงุช ---
|
88 |
+
coords = batch['coords'][0].cpu().numpy() # shape: (N, 2)
|
89 |
+
coords_preview = "\n".join(
|
90 |
+
[f"Node {i}: x={x:.3f}, y={y:.3f}" for i, (x, y) in enumerate(coords[:5])]
|
91 |
+
)
|
92 |
+
if coords.shape[0] > 5:
|
93 |
+
coords_preview += f"\n... (showing 5 of {coords.shape[0]} nodes)"
|
94 |
+
|
95 |
+
input_summary = f"๐ **Input Coordinates Preview:**\n{coords_preview}"
|
96 |
+
|
97 |
+
# --- ูุต ูุงู
ู ููุนุฑุถ ูุงูุทุจุงุนุฉ ---
|
98 |
result_text = (
|
99 |
+
"\n===== INPUT SUMMARY =====\n"
|
100 |
+
f"{coords_preview}\n"
|
101 |
"\n===== INFERENCE RESULT =====\n"
|
102 |
f"Time: {timestamp}\n"
|
103 |
+
f"Actor Model Total Cost: {total_time:.4f} units\n"
|
104 |
+
f"Nearest Neighbor Cost : {nn_time:.4f} units\n"
|
105 |
f"Improvement over NN : {improvement:.2f}%\n"
|
106 |
)
|
107 |
print(result_text)
|
108 |
output_lines.append(result_text)
|
109 |
|
110 |
+
# --- ูููุงุฌูุฉ Gradio ---
|
111 |
summary_text = (
|
112 |
+
f"๐ Time: {timestamp}\n"
|
113 |
+
f"๐ Actor Cost: {total_time:.4f} units\n"
|
114 |
+
f"๐ NN Cost: {nn_time:.4f} units\n"
|
115 |
+
f"๐ Improvement: {improvement:.2f}%\n\n"
|
116 |
+
f"{input_summary}"
|
117 |
)
|
118 |
|
119 |
+
# --- CSV ุญูุธ ---
|
120 |
write_header = not os.path.exists(csv_results_file)
|
121 |
with open(csv_results_file, 'a', newline='') as csvfile:
|
122 |
writer = csv.writer(csvfile)
|
|
|
129 |
f.write("\n".join(output_lines))
|
130 |
f.write("\n=============================\n")
|
131 |
|
132 |
+
# --- ุทุจุงุนุฉ ููู UI
|
133 |
print(f"\n๐ Summary for UI:\n{summary_text}")
|