Andy Lee committed on
Commit
d6d949c
·
1 Parent(s): fc843ce

feat: real hf studio gui

Browse files
Files changed (1) hide show
  1. app.py +182 -216
app.py CHANGED
@@ -1,247 +1,213 @@
1
- import streamlit as st
2
  import json
3
  import os
4
  import time
5
  from io import BytesIO
6
  from PIL import Image
7
- from typing import Dict, List, Any
8
 
9
  # 导入项目的核心逻辑和配置
10
- from geo_bot import (
11
- GeoBot,
12
- AGENT_PROMPT_TEMPLATE,
13
- BENCHMARK_PROMPT,
14
- ) # 导入Prompt模板以供复用
15
  from benchmark import MapGuesserBenchmark
16
- from config import MODELS_CONFIG, DATA_PATHS, SUCCESS_THRESHOLD_KM
17
  from langchain_openai import ChatOpenAI
18
  from langchain_anthropic import ChatAnthropic
19
  from langchain_google_genai import ChatGoogleGenerativeAI
20
 
21
- # --- 页面UI设置 ---
22
- st.set_page_config(page_title="MapCrunch AI Agent", layout="wide")
23
- st.title("🗺️ MapCrunch AI Agent")
24
- st.caption("一个通过多步交互探索和识别地理位置的AI智能体")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # --- Sidebar用于配置 ---
27
- with st.sidebar:
28
- st.header("⚙️ 运行配置")
 
29
 
30
- # 从HF Secrets获取API密钥 (部署到HF Spaces时,需要在Settings->Secrets中设置)
31
- os.environ["OPENAI_API_KEY"] = st.secrets.get("OPENAI_API_KEY", "")
32
- os.environ["ANTHROPIC_API_KEY"] = st.secrets.get("ANTHROPIC_API_KEY", "")
33
- # 添加其他你可能需要的API密钥
34
- # os.environ['GOOGLE_API_KEY'] = st.secrets.get("GOOGLE_API_KEY", "")
35
 
36
- model_choice = st.selectbox("选择AI模型", list(MODELS_CONFIG.keys()))
37
- steps_per_sample = st.slider(
38
- "每轮最大探索步数", min_value=3, max_value=20, value=10
39
- )
40
 
41
- # 加载golden labels以供选择
42
- try:
43
- with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
44
- golden_labels = json.load(f).get("samples", [])
45
- total_samples = len(golden_labels)
46
- num_samples_to_run = st.slider(
47
- "选择测试样本数量", min_value=1, max_value=total_samples, value=3
48
- )
49
- except FileNotFoundError:
50
- st.error(f"数据文件 '{DATA_PATHS['golden_labels']}' 未找到。请先准备数据。")
51
- golden_labels = []
52
- num_samples_to_run = 0
53
 
54
- start_button = st.button(
55
- "🚀 启动Agent Benchmark", disabled=(num_samples_to_run == 0), type="primary"
56
- )
57
 
58
- # --- Agent运行逻辑 ---
59
- if start_button:
60
- # 准备运行环境
61
- test_samples = golden_labels[:num_samples_to_run]
62
 
63
- config = MODELS_CONFIG.get(model_choice)
64
- model_class = globals()[config["class"]]
65
- model_instance_name = config["model_name"]
66
-
67
- # 初始化用于统计结果的辅助类和列表
68
- benchmark_helper = MapGuesserBenchmark()
69
- all_results = []
70
-
71
- st.info(
72
- f"即将开始Agent Benchmark... 模型: {model_choice}, 步数: {steps_per_sample}, 样本数: {num_samples_to_run}"
73
- )
74
 
75
- # 创建一个总进度条
76
- overall_progress_bar = st.progress(0, text="总进度")
77
-
78
- # 初始化Bot (注意:在HF Spaces上,必须以headless模式运行)
79
- # 将Bot的初始化放在循环外,可以复用同一个浏览器实例,提高效率
80
- with st.spinner("正在初始化浏览器和AI模型..."):
81
- bot = GeoBot(model=model_class, model_name=model_instance_name, headless=True)
82
-
83
- # 主循环,遍历所有选择的测试样本
84
- for i, sample in enumerate(test_samples):
85
- sample_id = sample.get("id", "N/A")
86
- st.divider()
87
- st.header(f"▶️ 运行中:样本 {i + 1}/{num_samples_to_run} (ID: {sample_id})")
88
-
89
- # 加载地图位置
90
- if not bot.controller.load_location_from_data(sample):
91
- st.error(f"加载样本 {sample_id} 失败,已跳过。")
92
- continue
93
-
94
- bot.controller.setup_clean_environment()
95
-
96
- # 为当前样本创建可视化布局
97
- col1, col2 = st.columns([2, 3])
98
- with col1:
99
- image_placeholder = st.empty()
100
- with col2:
101
- reasoning_placeholder = st.empty()
102
- action_placeholder = st.empty()
103
-
104
- # --- 内部的Agent探索循环 ---
105
- history = []
106
- final_guess = None
107
-
108
- for step in range(steps_per_sample):
109
- step_num = step + 1
110
- reasoning_placeholder.info(
111
- f"思考中... (第 {step_num}/{steps_per_sample} 步)"
112
- )
113
- action_placeholder.empty()
114
 
115
- # 观察并标记箭头
116
- bot.controller.label_arrows_on_screen()
117
- screenshot_bytes = bot.controller.take_street_view_screenshot()
118
- image_placeholder.image(
119
- screenshot_bytes, caption=f"Step {step_num} View", use_column_width=True
120
- )
121
 
122
- # 更新历史
123
- history.append(
124
- {
125
- "image_b64": bot.pil_to_base64(
126
- Image.open(BytesIO(screenshot_bytes))
127
- ),
128
- "action": "N/A",
129
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- # 思考
133
- prompt = AGENT_PROMPT_TEMPLATE.format(
134
- remaining_steps=steps_per_sample - step,
135
- history_text="\n".join(
136
- [f"Step {j + 1}: {h['action']}" for j, h in enumerate(history)]
137
- ),
138
- available_actions=json.dumps(bot.controller.get_available_actions()),
139
- )
140
- message = bot._create_message_with_history(
141
- prompt, [h["image_b64"] for h in history]
142
- )
143
- response = bot.model.invoke(message)
144
- decision = bot._parse_agent_response(response)
145
 
146
- if not decision: # Fallback
147
- decision = {
148
- "action_details": {"action": "PAN_RIGHT"},
149
- "reasoning": "Default recovery.",
150
- }
151
 
152
- action = decision.get("action_details", {}).get("action")
153
- history[-1]["action"] = action
 
 
154
 
155
- reasoning_placeholder.info(
156
- f"**AI Reasoning:**\n\n{decision.get('reasoning', 'N/A')}"
157
- )
158
- action_placeholder.success(f"**AI Action:** `{action}`")
159
-
160
- # 强制在最后一步进行GUESS
161
- if step_num == steps_per_sample and action != "GUESS":
162
- st.warning("已达最大步数,强制执行GUESS。")
163
- action = "GUESS"
164
-
165
- # 行动
166
- if action == "GUESS":
167
- lat, lon = (
168
- decision.get("action_details", {}).get("lat"),
169
- decision.get("action_details", {}).get("lon"),
170
- )
171
- if lat is not None and lon is not None:
172
- final_guess = (lat, lon)
173
- else:
174
- # 如果AI没在GUESS时提供坐标,再问一次
175
- # (这里的简化处理是直接结束,但在更复杂的版本可以再调用一次AI)
176
- st.error("GUESS动作中缺少坐标,本次猜测失败。")
177
- break # 结束当前样本的探索
178
-
179
- elif action == "MOVE_FORWARD":
180
- bot.controller.move("forward")
181
- elif action == "MOVE_BACKWARD":
182
- bot.controller.move("backward")
183
- elif action == "PAN_LEFT":
184
- bot.controller.pan_view("left")
185
- elif action == "PAN_RIGHT":
186
- bot.controller.pan_view("right")
187
-
188
- time.sleep(1) # 在步骤之间稍作停顿,改善视觉效果
189
-
190
- # --- 单个样本运行结束,计算并展示结果 ---
191
- true_coords = {"lat": sample.get("lat"), "lng": sample.get("lng")}
192
- distance_km = None
193
- is_success = False
194
-
195
- if final_guess:
196
- distance_km = benchmark_helper.calculate_distance(true_coords, final_guess)
197
- if distance_km is not None:
198
- is_success = distance_km <= SUCCESS_THRESHOLD_KM
199
-
200
- st.subheader("🎯 本轮结果")
201
- res_col1, res_col2, res_col3 = st.columns(3)
202
- res_col1.metric(
203
- "最终猜测 (Lat, Lon)", f"{final_guess[0]:.3f}, {final_guess[1]:.3f}"
204
  )
205
- res_col2.metric(
206
- "真实位置 (Lat, Lon)",
207
- f"{true_coords['lat']:.3f}, {true_coords['lng']:.3f}",
208
  )
209
- res_col3.metric(
210
- "距离误差",
211
- f"{distance_km:.1f} km" if distance_km is not None else "N/A",
212
- delta=f"{'成功' if is_success else '失败'}",
213
- delta_color=("inverse" if is_success else "off"),
214
  )
215
- else:
216
- st.error("Agent 未能做出最终猜测。")
217
-
218
- all_results.append(
219
- {
220
- "sample_id": sample_id,
221
- "model": model_choice,
222
- "true_coordinates": true_coords,
223
- "predicted_coordinates": final_guess,
224
- "distance_km": distance_km,
225
- "success": is_success,
226
- }
227
- )
228
-
229
- # 更新总进度条
230
- overall_progress_bar.progress(
231
- (i + 1) / num_samples_to_run, text=f"总进度: {i + 1}/{num_samples_to_run}"
232
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
- # --- 所有样本运行完毕,显示最终摘要 ---
235
- bot.close() # 关闭浏览器
236
- st.divider()
237
- st.header("🏁 Benchmark 最终摘要")
238
-
239
- summary = benchmark_helper.generate_summary(all_results)
240
- if summary and model_choice in summary:
241
- stats = summary[model_choice]
242
- sum_col1, sum_col2 = st.columns(2)
243
- sum_col1.metric("总成功率", f"{stats.get('success_rate', 0) * 100:.1f} %")
244
- sum_col2.metric("平均距离误差", f"{stats.get('average_distance_km', 0):.1f} km")
245
- st.dataframe(all_results) # 显示详细结果表格
246
- else:
247
- st.warning("没有足够的结果来生成摘要。")
 
1
+ import gradio as gr
2
  import json
3
  import os
4
  import time
5
  from io import BytesIO
6
  from PIL import Image
 
7
 
8
  # 导入项目的核心逻辑和配置
9
+ from geo_bot import GeoBot, AGENT_PROMPT_TEMPLATE
 
 
 
 
10
  from benchmark import MapGuesserBenchmark
11
+ from config import MODELS_CONFIG, DATA_PATHS
12
  from langchain_openai import ChatOpenAI
13
  from langchain_anthropic import ChatAnthropic
14
  from langchain_google_genai import ChatGoogleGenerativeAI
15
 
16
# --- Global setup ---
# API keys are read from the process environment (on HF Spaces, the configured
# secrets). `setdefault` keeps any key that is already set and otherwise stores
# an empty string, so later lookups of these variables never raise KeyError.
os.environ.setdefault("OPENAI_API_KEY", "")
os.environ.setdefault("ANTHROPIC_API_KEY", "")
# os.environ.setdefault("GOOGLE_API_KEY", "")

# Load the golden-label samples once at import time. A missing or unparsable
# data file must not prevent the UI from starting, so both cases fall back to
# an empty sample list and print a warning.
try:
    with open(DATA_PATHS["golden_labels"], "r", encoding="utf-8") as f:
        GOLDEN_LABELS = json.load(f).get("samples", [])
except FileNotFoundError:
    print(f"警告: 数据文件 '{DATA_PATHS['golden_labels']}' 未找到。")
    GOLDEN_LABELS = []
except json.JSONDecodeError as exc:
    print(f"警告: 数据文件 '{DATA_PATHS['golden_labels']}' 无法解析: {exc}")
    GOLDEN_LABELS = []
29
+
30
+
31
+ # --- 核心处理函数 (使用yield实现流式更新) ---
32
def run_agent_process(
    model_choice, steps_per_sample, sample_index, progress=gr.Progress(track_tqdm=True)
):
    """Run one exploration episode and stream UI updates.

    This is a generator: every ``yield`` is a dict mapping Gradio output
    components (``status_text``, ``image_output``, ``reasoning_output``,
    ``action_output``, ``result_output``) to their new values.

    Args:
        model_choice: Key into ``MODELS_CONFIG`` selecting the model.
        steps_per_sample: Maximum number of observe/think/act steps. Coerced
            with ``int()`` because a slider may deliver a float.
        sample_index: Position in ``GOLDEN_LABELS``, given either as an int or
            as a dropdown label such as ``"样本 3"``.
        progress: Gradio progress tracker; not referenced in the body.

    Yields:
        Partial update dicts for the output components.
    """
    # 1. Reset every output panel before any slow work starts.
    yield {
        status_text: "状态: 正在初始化浏览器和AI模型...",
        image_output: None,
        reasoning_output: "",
        action_output: "",
        result_output: "",
    }

    # Validate the inputs before launching a browser, so bad input is reported
    # in the status line instead of crashing the event handler.
    config = MODELS_CONFIG.get(model_choice)
    if not config:
        yield {status_text: f"错误: 未知的模型 '{model_choice}'。"}
        return

    try:
        sample = GOLDEN_LABELS[_parse_sample_index(sample_index)]
        total_steps = int(steps_per_sample)
    except (TypeError, ValueError, IndexError):
        yield {status_text: "错误: 无效的测试样本或步数。请重新选择。"}
        return

    # The config stores the model class by name; resolve it from the classes
    # imported at module level.
    model_class = globals()[config["class"]]
    bot = GeoBot(model=model_class, model_name=config["model_name"], headless=True)

    # try/finally guarantees the browser is closed on the early return, on any
    # exception, and when the generator is closed before it finishes.
    try:
        # 2. Load the selected sample location.
        ground_truth = {"lat": sample.get("lat"), "lng": sample.get("lng")}

        if not bot.controller.load_location_from_data(sample):
            yield {status_text: "错误: 加载地图位置失败。请重试。"}
            return

        bot.controller.setup_clean_environment()

        history = []
        final_guess = None

        # 3. Multi-step exploration loop.
        for step in range(total_steps):
            step_num = step + 1
            yield {status_text: f"状态: 探索中... (第 {step_num}/{total_steps} 步)"}

            # a. Observe
            bot.controller.label_arrows_on_screen()
            screenshot_bytes = bot.controller.take_street_view_screenshot()

            # b. Think
            current_screenshot_b64 = bot.pil_to_base64(
                Image.open(BytesIO(screenshot_bytes))
            )
            history.append({"image_b64": current_screenshot_b64, "action": "N/A"})

            prompt = AGENT_PROMPT_TEMPLATE.format(
                remaining_steps=total_steps - step,
                history_text="\n".join(
                    f"Step {j + 1}: {h['action']}" for j, h in enumerate(history)
                ),
                available_actions=json.dumps(bot.controller.get_available_actions()),
            )
            message = bot._create_message_with_history(
                prompt, [h["image_b64"] for h in history]
            )
            response = bot.model.invoke(message)
            decision = bot._parse_agent_response(response)

            if not decision:
                # Unparsable model output: fall back to a harmless pan.
                decision = {
                    "action_details": {"action": "PAN_RIGHT"},
                    "reasoning": "Default recovery.",
                }

            action_details = decision.get("action_details", {})
            action = action_details.get("action")
            reasoning = decision.get("reasoning", "N/A")
            history[-1]["action"] = action

            # c. Update the UI
            yield {
                image_output: Image.open(BytesIO(screenshot_bytes)),
                reasoning_output: f"**AI Reasoning:**\n\n{reasoning}",
                action_output: f"**AI Action:** `{action}`",
            }

            # d. Force a guess on the final step
            if step_num == total_steps and action != "GUESS":
                action = "GUESS"
                yield {status_text: "状态: 已达最大步数,强制执行GUESS..."}

            # e. Act
            if action == "GUESS":
                final_guess = _extract_guess(action_details)
                break
            elif action == "MOVE_FORWARD":
                bot.controller.move("forward")
            elif action == "MOVE_BACKWARD":
                bot.controller.move("backward")
            elif action == "PAN_LEFT":
                bot.controller.pan_view("left")
            elif action == "PAN_RIGHT":
                bot.controller.pan_view("right")

            time.sleep(1)  # brief pause between steps

        # 4. Loop finished: compute and publish the final result.
        yield {status_text: "状态: 探索完成,正在计算最终结果..."}

        if final_guess:
            # NOTE(review): the previous version computed the distance with
            # MapGuesserBenchmark.calculate_distance; confirm that GeoBot
            # exposes calculate_distance with the same contract.
            distance = bot.calculate_distance(ground_truth, final_guess)
            distance_text = f"{distance:.1f} km" if distance is not None else "N/A"
            # Built line by line so no source indentation leaks into the
            # Markdown (indented lines would render as a code block).
            result_text = "\n".join(
                [
                    "### 📍 最终结果",
                    f"- **真实位置:** `Lat: {ground_truth['lat']:.4f}, Lon: {ground_truth['lng']:.4f}`",
                    f"- **Agent猜测:** `Lat: {final_guess[0]:.4f}, Lon: {final_guess[1]:.4f}`",
                    f"- **距离误差:** `{distance_text}`",
                ]
            )
            yield {result_output: result_text, status_text: "状态: 完成!"}
        else:
            yield {
                result_output: "### 📍 最终结果\n\nAgent 未能做出有效猜测。",
                status_text: "状态: 完成!",
            }
    finally:
        bot.close()  # always release the browser


def _parse_sample_index(sample_index):
    """Return the non-negative list position encoded by *sample_index*.

    Accepts an int or a dropdown label whose last whitespace-separated token
    is the number (e.g. ``"样本 3"``). Raises TypeError, ValueError or
    IndexError for anything else.
    """
    if isinstance(sample_index, str):
        sample_index = sample_index.split()[-1]
    position = int(sample_index)
    if position < 0:
        raise ValueError(f"negative sample index: {position}")
    return position


def _extract_guess(action_details):
    """Return ``(lat, lon)`` as floats, or None when either is missing or non-numeric."""
    try:
        return float(action_details.get("lat")), float(action_details.get("lon"))
    except (TypeError, ValueError):
        return None
 
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
155
 
156
+ # --- Gradio UI 布局 ---
157
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🗺️ 可视化 GeoBot 智能体")
    gr.Markdown("选择配置并启动Agent,观察它如何通过探索来猜测自己的地理位置。")

    # Compute the choices once so the default values always come from them.
    _model_names = list(MODELS_CONFIG.keys())
    _sample_labels = [f"样本 {i}" for i in range(len(GOLDEN_LABELS))]
    if "gpt-4o" in _model_names:
        _default_model = "gpt-4o"
    else:
        _default_model = _model_names[0] if _model_names else None

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## ⚙️ 控制面板")
            model_choice = gr.Dropdown(
                _model_names, label="选择AI模型", value=_default_model
            )
            steps_per_sample = gr.Slider(
                3, 20, value=10, step=1, label="每轮最大探索步数"
            )
            sample_index = gr.Dropdown(
                _sample_labels,
                label="选择测试样本",
                value=_sample_labels[0] if _sample_labels else None,
            )
            # With no samples loaded there is nothing to run, so the button
            # starts disabled.
            start_button = gr.Button(
                "🚀 启动智能体", variant="primary", interactive=bool(_sample_labels)
            )
            status_text = gr.Markdown("状态: 等待启动")
            result_output = gr.Markdown()

        with gr.Column(scale=3):
            gr.Markdown("## 🕵️ Agent探索过程")
            image_output = gr.Image(label="Agent当前视角", height=600)
            with gr.Row():
                reasoning_output = gr.Markdown(label="AI 思考")
                action_output = gr.Markdown(label="AI 行动")

    # Event wiring: disable the button, run the agent, then re-enable it.
    # No `js` callback is used: when `fn` is also set, the callback's return
    # value is forwarded to `fn` as its inputs, which would replace the three
    # real inputs with placeholder strings. The first yield of
    # run_agent_process already resets all output panels.
    _run_event = start_button.click(
        fn=lambda: gr.update(interactive=False),
        inputs=None,
        outputs=start_button,
    ).then(
        fn=run_agent_process,
        inputs=[model_choice, steps_per_sample, sample_index],
        outputs=[
            status_text,
            image_output,
            reasoning_output,
            action_output,
            result_output,
        ],
    )
    # `.then` fires whether or not the run succeeded, so the button cannot
    # stay disabled after an error.
    _run_event.then(
        fn=lambda: gr.update(interactive=True),
        inputs=None,
        outputs=start_button,
    )

if __name__ == "__main__":
    demo.launch()