Spaces:
Sleeping
Sleeping
Andy Lee
committed on
Commit
·
f2b6ded
1
Parent(s):
6fda968
feat: app.py
Browse files
app.py
CHANGED
@@ -6,94 +6,95 @@ from io import BytesIO
|
|
6 |
from PIL import Image
|
7 |
from typing import Dict, List, Any
|
8 |
|
9 |
-
#
|
10 |
from geo_bot import (
|
11 |
GeoBot,
|
12 |
AGENT_PROMPT_TEMPLATE,
|
13 |
BENCHMARK_PROMPT,
|
14 |
-
)
|
15 |
from benchmark import MapGuesserBenchmark
|
16 |
from config import MODELS_CONFIG, DATA_PATHS, SUCCESS_THRESHOLD_KM
|
17 |
from langchain_openai import ChatOpenAI
|
18 |
from langchain_anthropic import ChatAnthropic
|
19 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
20 |
|
21 |
-
# ---
|
22 |
st.set_page_config(page_title="MapCrunch AI Agent", layout="wide")
|
23 |
st.title("🗺️ MapCrunch AI Agent")
|
24 |
-
st.caption(
|
|
|
|
|
25 |
|
26 |
-
# --- Sidebar
|
27 |
with st.sidebar:
|
28 |
-
st.header("⚙️
|
29 |
|
30 |
-
#
|
31 |
os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY", "")
|
32 |
os.environ["ANTHROPIC_API_KEY"] = st.secrets.get("ANTHROPIC_API_KEY", "")
|
33 |
-
# 添加其他你可能需要的API密钥
|
34 |
# os.environ['GOOGLE_API_KEY'] = st.secrets.get("GOOGLE_API_KEY", "")
|
35 |
|
36 |
-
model_choice = st.selectbox("
|
37 |
steps_per_sample = st.slider(
|
38 |
-
"
|
39 |
)
|
40 |
|
41 |
-
#
|
42 |
try:
|
43 |
with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
|
44 |
golden_labels = json.load(f).get("samples", [])
|
45 |
total_samples = len(golden_labels)
|
46 |
num_samples_to_run = st.slider(
|
47 |
-
"
|
48 |
)
|
49 |
except FileNotFoundError:
|
50 |
-
st.error(
|
|
|
|
|
51 |
golden_labels = []
|
52 |
num_samples_to_run = 0
|
53 |
|
54 |
start_button = st.button(
|
55 |
-
"🚀
|
56 |
)
|
57 |
|
58 |
-
# --- Agent
|
59 |
if start_button:
|
60 |
-
#
|
61 |
test_samples = golden_labels[:num_samples_to_run]
|
62 |
|
63 |
config = MODELS_CONFIG.get(model_choice)
|
64 |
model_class = globals()[config["class"]]
|
65 |
model_instance_name = config["model_name"]
|
66 |
|
67 |
-
#
|
68 |
benchmark_helper = MapGuesserBenchmark()
|
69 |
all_results = []
|
70 |
|
71 |
st.info(
|
72 |
-
f"
|
73 |
)
|
74 |
|
75 |
-
|
76 |
-
overall_progress_bar = st.progress(0, text="总进度")
|
77 |
|
78 |
-
#
|
79 |
-
|
80 |
-
|
81 |
bot = GeoBot(model=model_class, model_name=model_instance_name, headless=True)
|
82 |
|
83 |
-
#
|
84 |
for i, sample in enumerate(test_samples):
|
85 |
sample_id = sample.get("id", "N/A")
|
86 |
st.divider()
|
87 |
-
st.header(f"▶️
|
88 |
|
89 |
-
# 加载地图位置
|
90 |
if not bot.controller.load_location_from_data(sample):
|
91 |
-
st.error(f"
|
92 |
continue
|
93 |
|
94 |
bot.controller.setup_clean_environment()
|
95 |
|
96 |
-
#
|
97 |
col1, col2 = st.columns([2, 3])
|
98 |
with col1:
|
99 |
image_placeholder = st.empty()
|
@@ -101,25 +102,25 @@ if start_button:
|
|
101 |
reasoning_placeholder = st.empty()
|
102 |
action_placeholder = st.empty()
|
103 |
|
104 |
-
# ---
|
105 |
history = []
|
106 |
final_guess = None
|
107 |
|
108 |
for step in range(steps_per_sample):
|
109 |
step_num = step + 1
|
110 |
reasoning_placeholder.info(
|
111 |
-
f"
|
112 |
)
|
113 |
action_placeholder.empty()
|
114 |
|
115 |
-
#
|
116 |
bot.controller.label_arrows_on_screen()
|
117 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
118 |
image_placeholder.image(
|
119 |
screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
|
120 |
)
|
121 |
|
122 |
-
#
|
123 |
history.append(
|
124 |
{
|
125 |
"image_b64": bot.pil_to_base64(
|
@@ -129,7 +130,7 @@ if start_button:
|
|
129 |
}
|
130 |
)
|
131 |
|
132 |
-
#
|
133 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
134 |
remaining_steps=steps_per_sample - step,
|
135 |
history_text="\n".join(
|
@@ -157,12 +158,12 @@ if start_button:
|
|
157 |
)
|
158 |
action_placeholder.success(f"**AI Action:** `{action}`")
|
159 |
|
160 |
-
#
|
161 |
if step_num == steps_per_sample and action != "GUESS":
|
162 |
-
st.warning("
|
163 |
action = "GUESS"
|
164 |
|
165 |
-
#
|
166 |
if action == "GUESS":
|
167 |
lat, lon = (
|
168 |
decision.get("action_details", {}).get("lat"),
|
@@ -171,10 +172,10 @@ if start_button:
|
|
171 |
if lat is not None and lon is not None:
|
172 |
final_guess = (lat, lon)
|
173 |
else:
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
break #
|
178 |
|
179 |
elif action == "MOVE_FORWARD":
|
180 |
bot.controller.move("forward")
|
@@ -185,9 +186,9 @@ if start_button:
|
|
185 |
elif action == "PAN_RIGHT":
|
186 |
bot.controller.pan_view("right")
|
187 |
|
188 |
-
time.sleep(1) #
|
189 |
|
190 |
-
# ---
|
191 |
true_coords = {"lat": sample.get("lat"), "lng": sample.get("lng")}
|
192 |
distance_km = None
|
193 |
is_success = False
|
@@ -197,23 +198,23 @@ if start_button:
|
|
197 |
if distance_km is not None:
|
198 |
is_success = distance_km <= SUCCESS_THRESHOLD_KM
|
199 |
|
200 |
-
st.subheader("🎯
|
201 |
res_col1, res_col2, res_col3 = st.columns(3)
|
202 |
res_col1.metric(
|
203 |
-
"
|
204 |
)
|
205 |
res_col2.metric(
|
206 |
-
"
|
207 |
f"{true_coords['lat']:.3f}, {true_coords['lng']:.3f}",
|
208 |
)
|
209 |
res_col3.metric(
|
210 |
-
"
|
211 |
f"{distance_km:.1f} km" if distance_km is not None else "N/A",
|
212 |
-
delta=f"{'
|
213 |
delta_color=("inverse" if is_success else "off"),
|
214 |
)
|
215 |
else:
|
216 |
-
st.error("Agent
|
217 |
|
218 |
all_results.append(
|
219 |
{
|
@@ -226,22 +227,27 @@ if start_button:
|
|
226 |
}
|
227 |
)
|
228 |
|
229 |
-
#
|
230 |
overall_progress_bar.progress(
|
231 |
-
(i + 1) / num_samples_to_run,
|
|
|
232 |
)
|
233 |
|
234 |
-
# ---
|
235 |
-
bot.close() #
|
236 |
st.divider()
|
237 |
-
st.header("🏁 Benchmark
|
238 |
|
239 |
summary = benchmark_helper.generate_summary(all_results)
|
240 |
if summary and model_choice in summary:
|
241 |
stats = summary[model_choice]
|
242 |
sum_col1, sum_col2 = st.columns(2)
|
243 |
-
sum_col1.metric(
|
244 |
-
|
245 |
-
|
|
|
|
|
|
|
|
|
246 |
else:
|
247 |
-
st.warning("
|
|
|
6 |
from PIL import Image
|
7 |
from typing import Dict, List, Any
|
8 |
|
9 |
+
# Import core logic and configurations from the project
|
10 |
from geo_bot import (
|
11 |
GeoBot,
|
12 |
AGENT_PROMPT_TEMPLATE,
|
13 |
BENCHMARK_PROMPT,
|
14 |
+
)
|
15 |
from benchmark import MapGuesserBenchmark
|
16 |
from config import MODELS_CONFIG, DATA_PATHS, SUCCESS_THRESHOLD_KM
|
17 |
from langchain_openai import ChatOpenAI
|
18 |
from langchain_anthropic import ChatAnthropic
|
19 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
20 |
|
21 |
+
# --- Page UI Setup ---
|
22 |
st.set_page_config(page_title="MapCrunch AI Agent", layout="wide")
|
23 |
st.title("🗺️ MapCrunch AI Agent")
|
24 |
+
st.caption(
|
25 |
+
"An AI agent that explores and identifies geographic locations through multi-step interaction."
|
26 |
+
)
|
27 |
|
28 |
+
# --- Sidebar for Configuration ---
|
29 |
with st.sidebar:
|
30 |
+
st.header("⚙️ Agent Configuration")
|
31 |
|
32 |
+
# Get API keys from HF Secrets (must be set in Space settings when deploying)
|
33 |
os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY", "")
|
34 |
os.environ["ANTHROPIC_API_KEY"] = st.secrets.get("ANTHROPIC_API_KEY", "")
|
|
|
35 |
# os.environ['GOOGLE_API_KEY'] = st.secrets.get("GOOGLE_API_KEY", "")
|
36 |
|
37 |
+
model_choice = st.selectbox("Select AI Model", list(MODELS_CONFIG.keys()))
|
38 |
steps_per_sample = st.slider(
|
39 |
+
"Max Exploration Steps per Sample", min_value=3, max_value=20, value=10
|
40 |
)
|
41 |
|
42 |
+
# Load golden labels for selection
|
43 |
try:
|
44 |
with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
|
45 |
golden_labels = json.load(f).get("samples", [])
|
46 |
total_samples = len(golden_labels)
|
47 |
num_samples_to_run = st.slider(
|
48 |
+
"Number of Samples to Test", min_value=1, max_value=total_samples, value=3
|
49 |
)
|
50 |
except FileNotFoundError:
|
51 |
+
st.error(
|
52 |
+
f"Data file '{DATA_PATHS['golden_labels']}' not found. Please prepare the data."
|
53 |
+
)
|
54 |
golden_labels = []
|
55 |
num_samples_to_run = 0
|
56 |
|
57 |
start_button = st.button(
|
58 |
+
"🚀 Start Agent Benchmark", disabled=(num_samples_to_run == 0), type="primary"
|
59 |
)
|
60 |
|
61 |
+
# --- Agent Execution Logic ---
|
62 |
if start_button:
|
63 |
+
# Prepare the environment
|
64 |
test_samples = golden_labels[:num_samples_to_run]
|
65 |
|
66 |
config = MODELS_CONFIG.get(model_choice)
|
67 |
model_class = globals()[config["class"]]
|
68 |
model_instance_name = config["model_name"]
|
69 |
|
70 |
+
# Initialize helpers and result lists
|
71 |
benchmark_helper = MapGuesserBenchmark()
|
72 |
all_results = []
|
73 |
|
74 |
st.info(
|
75 |
+
f"Starting Agent Benchmark... Model: {model_choice}, Steps: {steps_per_sample}, Samples: {num_samples_to_run}"
|
76 |
)
|
77 |
|
78 |
+
overall_progress_bar = st.progress(0, text="Overall Progress")
|
|
|
79 |
|
80 |
+
# Initialize the bot outside the loop to reuse the browser instance for efficiency
|
81 |
+
with st.spinner("Initializing browser and AI model..."):
|
82 |
+
# Note: Must run in headless mode on HF Spaces
|
83 |
bot = GeoBot(model=model_class, model_name=model_instance_name, headless=True)
|
84 |
|
85 |
+
# Main loop to iterate through all selected test samples
|
86 |
for i, sample in enumerate(test_samples):
|
87 |
sample_id = sample.get("id", "N/A")
|
88 |
st.divider()
|
89 |
+
st.header(f"▶️ Running Sample {i + 1}/{num_samples_to_run} (ID: {sample_id})")
|
90 |
|
|
|
91 |
if not bot.controller.load_location_from_data(sample):
|
92 |
+
st.error(f"Failed to load location for sample {sample_id}. Skipping.")
|
93 |
continue
|
94 |
|
95 |
bot.controller.setup_clean_environment()
|
96 |
|
97 |
+
# Create the visualization layout for the current sample
|
98 |
col1, col2 = st.columns([2, 3])
|
99 |
with col1:
|
100 |
image_placeholder = st.empty()
|
|
|
102 |
reasoning_placeholder = st.empty()
|
103 |
action_placeholder = st.empty()
|
104 |
|
105 |
+
# --- Inner agent exploration loop ---
|
106 |
history = []
|
107 |
final_guess = None
|
108 |
|
109 |
for step in range(steps_per_sample):
|
110 |
step_num = step + 1
|
111 |
reasoning_placeholder.info(
|
112 |
+
f"Thinking... (Step {step_num}/{steps_per_sample})"
|
113 |
)
|
114 |
action_placeholder.empty()
|
115 |
|
116 |
+
# Observe and label arrows
|
117 |
bot.controller.label_arrows_on_screen()
|
118 |
screenshot_bytes = bot.controller.take_street_view_screenshot()
|
119 |
image_placeholder.image(
|
120 |
screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
|
121 |
)
|
122 |
|
123 |
+
# Update history
|
124 |
history.append(
|
125 |
{
|
126 |
"image_b64": bot.pil_to_base64(
|
|
|
130 |
}
|
131 |
)
|
132 |
|
133 |
+
# Think
|
134 |
prompt = AGENT_PROMPT_TEMPLATE.format(
|
135 |
remaining_steps=steps_per_sample - step,
|
136 |
history_text="\n".join(
|
|
|
158 |
)
|
159 |
action_placeholder.success(f"**AI Action:** `{action}`")
|
160 |
|
161 |
+
# Force a GUESS on the last step
|
162 |
if step_num == steps_per_sample and action != "GUESS":
|
163 |
+
st.warning("Max steps reached. Forcing a GUESS action.")
|
164 |
action = "GUESS"
|
165 |
|
166 |
+
# Act
|
167 |
if action == "GUESS":
|
168 |
lat, lon = (
|
169 |
decision.get("action_details", {}).get("lat"),
|
|
|
172 |
if lat is not None and lon is not None:
|
173 |
final_guess = (lat, lon)
|
174 |
else:
|
175 |
+
st.error(
|
176 |
+
"GUESS action was missing coordinates. Guess failed for this sample."
|
177 |
+
)
|
178 |
+
break # End exploration for the current sample
|
179 |
|
180 |
elif action == "MOVE_FORWARD":
|
181 |
bot.controller.move("forward")
|
|
|
186 |
elif action == "PAN_RIGHT":
|
187 |
bot.controller.pan_view("right")
|
188 |
|
189 |
+
time.sleep(1) # A brief pause between steps for better visualization
|
190 |
|
191 |
+
# --- End of single sample run, calculate and display results ---
|
192 |
true_coords = {"lat": sample.get("lat"), "lng": sample.get("lng")}
|
193 |
distance_km = None
|
194 |
is_success = False
|
|
|
198 |
if distance_km is not None:
|
199 |
is_success = distance_km <= SUCCESS_THRESHOLD_KM
|
200 |
|
201 |
+
st.subheader("🎯 Round Result")
|
202 |
res_col1, res_col2, res_col3 = st.columns(3)
|
203 |
res_col1.metric(
|
204 |
+
"Final Guess (Lat, Lon)", f"{final_guess[0]:.3f}, {final_guess[1]:.3f}"
|
205 |
)
|
206 |
res_col2.metric(
|
207 |
+
"Ground Truth (Lat, Lon)",
|
208 |
f"{true_coords['lat']:.3f}, {true_coords['lng']:.3f}",
|
209 |
)
|
210 |
res_col3.metric(
|
211 |
+
"Distance Error",
|
212 |
f"{distance_km:.1f} km" if distance_km is not None else "N/A",
|
213 |
+
delta=f"{'Success' if is_success else 'Failure'}",
|
214 |
delta_color=("inverse" if is_success else "off"),
|
215 |
)
|
216 |
else:
|
217 |
+
st.error("Agent failed to make a final guess.")
|
218 |
|
219 |
all_results.append(
|
220 |
{
|
|
|
227 |
}
|
228 |
)
|
229 |
|
230 |
+
# Update overall progress bar
|
231 |
overall_progress_bar.progress(
|
232 |
+
(i + 1) / num_samples_to_run,
|
233 |
+
text=f"Overall Progress: {i + 1}/{num_samples_to_run}",
|
234 |
)
|
235 |
|
236 |
+
# --- End of all samples, display final summary ---
|
237 |
+
bot.close() # Close the browser
|
238 |
st.divider()
|
239 |
+
st.header("🏁 Benchmark Summary")
|
240 |
|
241 |
summary = benchmark_helper.generate_summary(all_results)
|
242 |
if summary and model_choice in summary:
|
243 |
stats = summary[model_choice]
|
244 |
sum_col1, sum_col2 = st.columns(2)
|
245 |
+
sum_col1.metric(
|
246 |
+
"Overall Success Rate", f"{stats.get('success_rate', 0) * 100:.1f} %"
|
247 |
+
)
|
248 |
+
sum_col2.metric(
|
249 |
+
"Average Distance Error", f"{stats.get('average_distance_km', 0):.1f} km"
|
250 |
+
)
|
251 |
+
st.dataframe(all_results) # Display the detailed results table
|
252 |
else:
|
253 |
+
st.warning("Not enough results to generate a summary.")
|