yichuan commited on
Commit
1647cf8
·
unverified ·
2 Parent(s): d79cd74 aeb6ce6

Merge pull request #2 from yichuan520030910320/qwen

Browse files
Files changed (5) hide show
  1. app.py +78 -71
  2. benchmark.py +4 -21
  3. config.py +97 -11
  4. geo_bot.py +0 -6
  5. main.py +2 -8
app.py CHANGED
@@ -3,38 +3,31 @@ import json
3
  import os
4
  import time
5
  import re
6
- from io import BytesIO
7
- from PIL import Image
8
  from pathlib import Path
9
- import pyperclip
10
 
11
- from geo_bot import GeoBot, AGENT_PROMPT_TEMPLATE
12
  from benchmark import MapGuesserBenchmark
13
- from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM, DEFAULT_MODEL, DEFAULT_TEMPERATURE
14
- from langchain_openai import ChatOpenAI
15
- from langchain_anthropic import ChatAnthropic
16
- from langchain_google_genai import ChatGoogleGenerativeAI
17
- from hf_chat import HuggingFaceChat
18
-
19
- # Simple API key setup
20
- if "OPENAI_API_KEY" in st.secrets:
21
- os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
22
- if "ANTHROPIC_API_KEY" in st.secrets:
23
- os.environ["ANTHROPIC_API_KEY"] = st.secrets["ANTHROPIC_API_KEY"]
24
- if "GOOGLE_API_KEY" in st.secrets:
25
- os.environ["GOOGLE_API_KEY"] = st.secrets["GOOGLE_API_KEY"]
26
- if "HF_TOKEN" in st.secrets:
27
- os.environ["HF_TOKEN"] = st.secrets["HF_TOKEN"]
28
 
29
 
30
  def convert_google_to_mapcrunch_url(google_url):
31
  """Convert Google Maps URL to MapCrunch URL format."""
32
  try:
33
  # Extract coordinates using regex
34
- match = re.search(r'@(-?\d+\.\d+),(-?\d+\.\d+)', google_url)
35
  if not match:
36
  return None
37
-
38
  lat, lon = match.groups()
39
  # MapCrunch format: lat_lon_heading_pitch_zoom
40
  # Using default values for heading (317.72), pitch (0.86), and zoom (0)
@@ -58,21 +51,10 @@ def get_available_datasets():
58
  return datasets if datasets else ["default"]
59
 
60
 
61
- def get_model_class(class_name):
62
- if class_name == "ChatOpenAI":
63
- return ChatOpenAI
64
- elif class_name == "ChatAnthropic":
65
- return ChatAnthropic
66
- elif class_name == "ChatGoogleGenerativeAI":
67
- return ChatGoogleGenerativeAI
68
- elif class_name == "HuggingFaceChat":
69
- return HuggingFaceChat
70
- else:
71
- raise ValueError(f"Unknown model class: {class_name}")
72
-
73
-
74
  # UI Setup
75
- st.set_page_config(page_title="🧠 Omniscient - Multiturn Geographic Intelligence", layout="wide")
 
 
76
  st.title("🧠 Omniscient")
77
  st.markdown("""
78
  ### *An all-seeing AI agent for geographic analysis and deduction*
@@ -86,14 +68,18 @@ with st.sidebar:
86
 
87
  # Mode selection
88
  mode = st.radio("Mode", ["Dataset Mode", "Online Mode"], index=0)
89
-
90
  if mode == "Dataset Mode":
91
  # Get available datasets and ensure we have a valid default
92
  available_datasets = get_available_datasets()
93
  default_dataset = available_datasets[0] if available_datasets else "default"
94
-
95
  dataset_choice = st.selectbox("Dataset", available_datasets, index=0)
96
- model_choice = st.selectbox("Model", list(MODELS_CONFIG.keys()), index=list(MODELS_CONFIG.keys()).index(DEFAULT_MODEL))
 
 
 
 
97
  steps_per_sample = st.slider("Max Steps", 1, 20, 10)
98
  temperature = st.slider(
99
  "Temperature",
@@ -109,14 +95,16 @@ with st.sidebar:
109
  try:
110
  with open(data_paths["golden_labels"], "r") as f:
111
  golden_labels = json.load(f).get("samples", [])
112
-
113
  st.info(f"Dataset '{dataset_choice}' has {len(golden_labels)} samples")
114
  if len(golden_labels) == 0:
115
  st.error(f"Dataset '{dataset_choice}' contains no samples!")
116
  st.stop()
117
-
118
  except FileNotFoundError:
119
- st.error(f"❌ Dataset '{dataset_choice}' not found at {data_paths['golden_labels']}")
 
 
120
  st.info("💡 Available datasets: " + ", ".join(available_datasets))
121
  st.stop()
122
  except Exception as e:
@@ -128,19 +116,21 @@ with st.sidebar:
128
  )
129
  else: # Online Mode
130
  st.info("Enter a URL to analyze a specific location")
131
-
132
  # Add example URLs
133
  example_google_url = "https://www.google.com/maps/@37.8728123,-122.2445339,3a,75y,3.36h,90t/data=!3m7!1e1!3m5!1s4DTABKOpCL6hdNRgnAHTgw!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fcb_client%3Dmaps_sv.tactile%26w%3D900%26h%3D600%26pitch%3D0%26panoid%3D4DTABKOpCL6hdNRgnAHTgw%26yaw%3D3.3576431!7i13312!8i6656?entry=ttu"
134
- example_mapcrunch_url = "http://www.mapcrunch.com/p/37.882284_-122.269626_293.91_-6.63_0"
135
-
 
 
136
  # Create tabs for different URL types
137
  input_tab1, input_tab2 = st.tabs(["Google Maps URL", "MapCrunch URL"])
138
-
139
  google_url = ""
140
  mapcrunch_url = ""
141
  golden_labels = None
142
  num_samples = None
143
-
144
  with input_tab1:
145
  url_col1, url_col2 = st.columns([3, 1])
146
  with url_col1:
@@ -149,56 +139,68 @@ with st.sidebar:
149
  placeholder="https://www.google.com/maps/@37.5851338,-122.1519467,9z?entry=ttu",
150
  key="google_maps_url",
151
  )
152
- st.markdown(f"💡 **Example Location:** [View in Google Maps]({example_google_url})")
 
 
153
  if google_url:
154
  mapcrunch_url_converted = convert_google_to_mapcrunch_url(google_url)
155
  if mapcrunch_url_converted:
156
  st.success(f"Converted to MapCrunch URL: {mapcrunch_url_converted}")
157
  try:
158
- golden_labels = [{
159
- "id": "online",
160
- "lat": float(re.search(r'@(-?\d+\.\d+),(-?\d+\.\d+)', google_url).group(1)),
161
- "lng": float(re.search(r'@(-?\d+\.\d+),(-?\d+\.\d+)', google_url).group(2)),
162
- "url": mapcrunch_url_converted
163
- }]
 
 
 
 
 
 
 
 
 
164
  num_samples = 1
165
  except Exception as e:
166
  st.error(f"Invalid Google Maps URL format: {str(e)}")
167
  else:
168
  st.error("Invalid Google Maps URL format")
169
-
170
  with input_tab2:
171
  st.markdown("💡 **Example Location:**")
172
  st.markdown(f"[View in MapCrunch]({example_mapcrunch_url})")
173
  st.code(example_mapcrunch_url, language="text")
174
  mapcrunch_url = st.text_input(
175
- "MapCrunch URL",
176
- placeholder=example_mapcrunch_url,
177
- key="mapcrunch_url"
178
  )
179
  if mapcrunch_url:
180
  try:
181
- coords = mapcrunch_url.split('/')[-1].split('_')
182
  lat, lon = float(coords[0]), float(coords[1])
183
- golden_labels = [{
184
- "id": "online",
185
- "lat": lat,
186
- "lng": lon,
187
- "url": mapcrunch_url
188
- }]
189
  num_samples = 1
190
  except Exception as e:
191
  st.error(f"Invalid MapCrunch URL format: {str(e)}")
192
-
193
  # Only stop if neither input is provided
194
  if not google_url and not mapcrunch_url:
195
- st.warning("Please enter a Google Maps URL or MapCrunch URL, or use the example above.")
 
 
196
  st.stop()
197
  if golden_labels is None or num_samples is None:
198
  st.warning("Please enter a valid URL.")
199
  st.stop()
200
-
201
- model_choice = st.selectbox("Model", list(MODELS_CONFIG.keys()), index=list(MODELS_CONFIG.keys()).index(DEFAULT_MODEL))
 
 
 
 
202
  steps_per_sample = st.slider("Max Steps", 1, 20, 10)
203
  temperature = st.slider(
204
  "Temperature",
@@ -217,7 +219,9 @@ if start_button:
217
  config = MODELS_CONFIG[model_choice]
218
  model_class = get_model_class(config["class"])
219
 
220
- benchmark_helper = MapGuesserBenchmark(dataset_name=dataset_choice if mode == "Dataset Mode" else "online")
 
 
221
  all_results = []
222
 
223
  progress_bar = st.progress(0)
@@ -238,7 +242,7 @@ if start_button:
238
  else:
239
  # Load from dataset as before
240
  bot.controller.load_location_from_data(sample)
241
-
242
  bot.controller.setup_clean_environment()
243
 
244
  # Create containers for UI updates
@@ -421,7 +425,10 @@ if start_button:
421
  mime="application/json",
422
  )
423
 
 
424
  def handle_tab_completion():
425
  """Handle tab completion for the Google Maps URL input."""
426
  if st.session_state.google_maps_url == "":
427
- st.session_state.google_maps_url = "https://www.google.com/maps/@37.5851338,-122.1519467,9z?entry=ttu"
 
 
 
3
  import os
4
  import time
5
  import re
 
 
6
  from pathlib import Path
 
7
 
8
+ from geo_bot import GeoBot
9
  from benchmark import MapGuesserBenchmark
10
+ from config import (
11
+ MODELS_CONFIG,
12
+ get_data_paths,
13
+ SUCCESS_THRESHOLD_KM,
14
+ get_model_class,
15
+ DEFAULT_MODEL,
16
+ DEFAULT_TEMPERATURE,
17
+ setup_environment_variables,
18
+ )
19
+
20
+ setup_environment_variables(st.secrets)
 
 
 
 
21
 
22
 
23
  def convert_google_to_mapcrunch_url(google_url):
24
  """Convert Google Maps URL to MapCrunch URL format."""
25
  try:
26
  # Extract coordinates using regex
27
+ match = re.search(r"@(-?\d+\.\d+),(-?\d+\.\d+)", google_url)
28
  if not match:
29
  return None
30
+
31
  lat, lon = match.groups()
32
  # MapCrunch format: lat_lon_heading_pitch_zoom
33
  # Using default values for heading (317.72), pitch (0.86), and zoom (0)
 
51
  return datasets if datasets else ["default"]
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # UI Setup
55
+ st.set_page_config(
56
+ page_title="🧠 Omniscient - Multiturn Geographic Intelligence", layout="wide"
57
+ )
58
  st.title("🧠 Omniscient")
59
  st.markdown("""
60
  ### *An all-seeing AI agent for geographic analysis and deduction*
 
68
 
69
  # Mode selection
70
  mode = st.radio("Mode", ["Dataset Mode", "Online Mode"], index=0)
71
+
72
  if mode == "Dataset Mode":
73
  # Get available datasets and ensure we have a valid default
74
  available_datasets = get_available_datasets()
75
  default_dataset = available_datasets[0] if available_datasets else "default"
76
+
77
  dataset_choice = st.selectbox("Dataset", available_datasets, index=0)
78
+ model_choice = st.selectbox(
79
+ "Model",
80
+ list(MODELS_CONFIG.keys()),
81
+ index=list(MODELS_CONFIG.keys()).index(DEFAULT_MODEL),
82
+ )
83
  steps_per_sample = st.slider("Max Steps", 1, 20, 10)
84
  temperature = st.slider(
85
  "Temperature",
 
95
  try:
96
  with open(data_paths["golden_labels"], "r") as f:
97
  golden_labels = json.load(f).get("samples", [])
98
+
99
  st.info(f"Dataset '{dataset_choice}' has {len(golden_labels)} samples")
100
  if len(golden_labels) == 0:
101
  st.error(f"Dataset '{dataset_choice}' contains no samples!")
102
  st.stop()
103
+
104
  except FileNotFoundError:
105
+ st.error(
106
+ f"❌ Dataset '{dataset_choice}' not found at {data_paths['golden_labels']}"
107
+ )
108
  st.info("💡 Available datasets: " + ", ".join(available_datasets))
109
  st.stop()
110
  except Exception as e:
 
116
  )
117
  else: # Online Mode
118
  st.info("Enter a URL to analyze a specific location")
119
+
120
  # Add example URLs
121
  example_google_url = "https://www.google.com/maps/@37.8728123,-122.2445339,3a,75y,3.36h,90t/data=!3m7!1e1!3m5!1s4DTABKOpCL6hdNRgnAHTgw!2e0!6shttps:%2F%2Fstreetviewpixels-pa.googleapis.com%2Fv1%2Fthumbnail%3Fcb_client%3Dmaps_sv.tactile%26w%3D900%26h%3D600%26pitch%3D0%26panoid%3D4DTABKOpCL6hdNRgnAHTgw%26yaw%3D3.3576431!7i13312!8i6656?entry=ttu"
122
+ example_mapcrunch_url = (
123
+ "http://www.mapcrunch.com/p/37.882284_-122.269626_293.91_-6.63_0"
124
+ )
125
+
126
  # Create tabs for different URL types
127
  input_tab1, input_tab2 = st.tabs(["Google Maps URL", "MapCrunch URL"])
128
+
129
  google_url = ""
130
  mapcrunch_url = ""
131
  golden_labels = None
132
  num_samples = None
133
+
134
  with input_tab1:
135
  url_col1, url_col2 = st.columns([3, 1])
136
  with url_col1:
 
139
  placeholder="https://www.google.com/maps/@37.5851338,-122.1519467,9z?entry=ttu",
140
  key="google_maps_url",
141
  )
142
+ st.markdown(
143
+ f"💡 **Example Location:** [View in Google Maps]({example_google_url})"
144
+ )
145
  if google_url:
146
  mapcrunch_url_converted = convert_google_to_mapcrunch_url(google_url)
147
  if mapcrunch_url_converted:
148
  st.success(f"Converted to MapCrunch URL: {mapcrunch_url_converted}")
149
  try:
150
+ match = re.search(r"@(-?\d+\.\d+),(-?\d+\.\d+)", google_url)
151
+ if not match:
152
+ st.error("Invalid Google Maps URL format")
153
+ st.stop()
154
+
155
+ lat, lon = match.groups()
156
+
157
+ golden_labels = [
158
+ {
159
+ "id": "online",
160
+ "lat": float(lat),
161
+ "lng": float(lon),
162
+ "url": mapcrunch_url_converted,
163
+ }
164
+ ]
165
  num_samples = 1
166
  except Exception as e:
167
  st.error(f"Invalid Google Maps URL format: {str(e)}")
168
  else:
169
  st.error("Invalid Google Maps URL format")
170
+
171
  with input_tab2:
172
  st.markdown("💡 **Example Location:**")
173
  st.markdown(f"[View in MapCrunch]({example_mapcrunch_url})")
174
  st.code(example_mapcrunch_url, language="text")
175
  mapcrunch_url = st.text_input(
176
+ "MapCrunch URL", placeholder=example_mapcrunch_url, key="mapcrunch_url"
 
 
177
  )
178
  if mapcrunch_url:
179
  try:
180
+ coords = mapcrunch_url.split("/")[-1].split("_")
181
  lat, lon = float(coords[0]), float(coords[1])
182
+ golden_labels = [
183
+ {"id": "online", "lat": lat, "lng": lon, "url": mapcrunch_url}
184
+ ]
 
 
 
185
  num_samples = 1
186
  except Exception as e:
187
  st.error(f"Invalid MapCrunch URL format: {str(e)}")
188
+
189
  # Only stop if neither input is provided
190
  if not google_url and not mapcrunch_url:
191
+ st.warning(
192
+ "Please enter a Google Maps URL or MapCrunch URL, or use the example above."
193
+ )
194
  st.stop()
195
  if golden_labels is None or num_samples is None:
196
  st.warning("Please enter a valid URL.")
197
  st.stop()
198
+
199
+ model_choice = st.selectbox(
200
+ "Model",
201
+ list(MODELS_CONFIG.keys()),
202
+ index=list(MODELS_CONFIG.keys()).index(DEFAULT_MODEL),
203
+ )
204
  steps_per_sample = st.slider("Max Steps", 1, 20, 10)
205
  temperature = st.slider(
206
  "Temperature",
 
219
  config = MODELS_CONFIG[model_choice]
220
  model_class = get_model_class(config["class"])
221
 
222
+ benchmark_helper = MapGuesserBenchmark(
223
+ dataset_name=dataset_choice if mode == "Dataset Mode" else "online"
224
+ )
225
  all_results = []
226
 
227
  progress_bar = st.progress(0)
 
242
  else:
243
  # Load from dataset as before
244
  bot.controller.load_location_from_data(sample)
245
+
246
  bot.controller.setup_clean_environment()
247
 
248
  # Create containers for UI updates
 
425
  mime="application/json",
426
  )
427
 
428
+
429
  def handle_tab_completion():
430
  """Handle tab completion for the Google Maps URL input."""
431
  if st.session_state.google_maps_url == "":
432
+ st.session_state.google_maps_url = (
433
+ "https://www.google.com/maps/@37.5851338,-122.1519467,9z?entry=ttu"
434
+ )
benchmark.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
  import math
10
 
11
  from geo_bot import GeoBot
12
- from config import get_data_paths, MODELS_CONFIG, SUCCESS_THRESHOLD_KM
13
 
14
 
15
  class MapGuesserBenchmark:
@@ -29,25 +29,6 @@ class MapGuesserBenchmark:
29
  except Exception:
30
  return []
31
 
32
- def get_model_class(self, model_name: str):
33
- config = MODELS_CONFIG.get(model_name)
34
- if not config:
35
- raise ValueError(f"Unknown model: {model_name}")
36
- class_name, model_class_name = config["class"], config["model_name"]
37
- if class_name == "ChatOpenAI":
38
- from langchain_openai import ChatOpenAI
39
-
40
- return ChatOpenAI, model_class_name
41
- if class_name == "ChatAnthropic":
42
- from langchain_anthropic import ChatAnthropic
43
-
44
- return ChatAnthropic, model_class_name
45
- if class_name == "ChatGoogleGenerativeAI":
46
- from langchain_google_genai import ChatGoogleGenerativeAI
47
-
48
- return ChatGoogleGenerativeAI, model_class_name
49
- raise ValueError(f"Unknown model class: {class_name}")
50
-
51
  def calculate_distance(
52
  self, true_coords: Dict, predicted_coords: Optional[Tuple[float, float]]
53
  ) -> Optional[float]:
@@ -99,7 +80,9 @@ class MapGuesserBenchmark:
99
  all_results = []
100
  for model_name in models_to_test:
101
  print(f"\n🤖 Testing model: {model_name}")
102
- model_class, model_class_name = self.get_model_class(model_name)
 
 
103
 
104
  try:
105
  with GeoBot(
 
9
  import math
10
 
11
  from geo_bot import GeoBot
12
+ from config import get_data_paths, MODELS_CONFIG, SUCCESS_THRESHOLD_KM, get_model_class
13
 
14
 
15
  class MapGuesserBenchmark:
 
29
  except Exception:
30
  return []
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def calculate_distance(
33
  self, true_coords: Dict, predicted_coords: Optional[Tuple[float, float]]
34
  ) -> Optional[float]:
 
80
  all_results = []
81
  for model_name in models_to_test:
82
  print(f"\n🤖 Testing model: {model_name}")
83
+ model_config = MODELS_CONFIG[model_name]
84
+ model_class = get_model_class(model_config["class"])
85
+ model_class_name = model_config["model_name"]
86
 
87
  try:
88
  with GeoBot(
config.py CHANGED
@@ -1,5 +1,10 @@
1
  # Configuration file for MapCrunch benchmark
2
 
 
 
 
 
 
3
  SUCCESS_THRESHOLD_KM = 100
4
 
5
  # MapCrunch settings
@@ -42,10 +47,15 @@ MODELS_CONFIG = {
42
  "model_name": "gpt-4o-mini",
43
  "description": "OpenAI GPT-4o Mini",
44
  },
45
- "claude-3.5-sonnet": {
 
 
 
 
 
46
  "class": "ChatAnthropic",
47
- "model_name": "claude-3-5-sonnet-20240620",
48
- "description": "Anthropic Claude 3.5 Sonnet",
49
  },
50
  "gemini-1.5-pro": {
51
  "class": "ChatGoogleGenerativeAI",
@@ -62,18 +72,93 @@ MODELS_CONFIG = {
62
  "model_name": "gemini-2.5-pro-preview-06-05",
63
  "description": "Google Gemini 2.5 Pro",
64
  },
65
- "qwen2-vl-7b": {
66
- "class": "HuggingFaceChat",
67
- "model_name": "Qwen/Qwen2-VL-7B-Instruct",
68
- "description": "Qwen2-VL 7B (older but API supported)",
 
 
 
 
 
 
 
 
 
 
69
  },
70
- "qwen2-vl-2b": {
71
- "class": "HuggingFaceChat",
72
- "model_name": "Qwen/Qwen2-VL-2B-Instruct",
73
- "description": "Qwen2-VL 2B (faster, API supported)",
74
  },
75
  }
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Data paths - now supports named datasets
78
  def get_data_paths(dataset_name: str = "default"):
79
  """Get data paths for a specific dataset"""
@@ -83,5 +168,6 @@ def get_data_paths(dataset_name: str = "default"):
83
  "results": f"results/{dataset_name}/",
84
  }
85
 
 
86
  # Backward compatibility - default paths
87
  DATA_PATHS = get_data_paths("default")
 
1
  # Configuration file for MapCrunch benchmark
2
 
3
+ from pydantic import SecretStr, Field
4
+ from typing import Optional
5
+ import os
6
+
7
+
8
  SUCCESS_THRESHOLD_KM = 100
9
 
10
  # MapCrunch settings
 
47
  "model_name": "gpt-4o-mini",
48
  "description": "OpenAI GPT-4o Mini",
49
  },
50
+ "claude-3-7-sonnet": {
51
+ "class": "ChatAnthropic",
52
+ "model_name": "claude-3-7-sonnet-20250219",
53
+ "description": "Anthropic Claude 3.7 Sonnet",
54
+ },
55
+ "claude-4-sonnet": {
56
  "class": "ChatAnthropic",
57
+ "model_name": "claude-4-sonnet-20250514",
58
+ "description": "Anthropic Claude 4 Sonnet",
59
  },
60
  "gemini-1.5-pro": {
61
  "class": "ChatGoogleGenerativeAI",
 
72
  "model_name": "gemini-2.5-pro-preview-06-05",
73
  "description": "Google Gemini 2.5 Pro",
74
  },
75
+ "qwen-vl-max": {
76
+ "class": "OpenRouter",
77
+ "model_name": "qwen/qwen-vl-max",
78
+ "description": "Qwen VL Max - OpenRouter (Best Performance)",
79
+ },
80
+ "qwen2.5-vl-32b-free": {
81
+ "class": "OpenRouter",
82
+ "model_name": "qwen/qwen2.5-vl-32b-instruct:free",
83
+ "description": "Qwen2.5 VL 32B - OpenRouter (FREE!)",
84
+ },
85
+ "qwen2.5-vl-7b": {
86
+ "class": "OpenRouter",
87
+ "model_name": "qwen/qwen2.5-vl-7b-instruct",
88
+ "description": "Qwen2.5 VL 7B - OpenRouter",
89
  },
90
+ "qwen2.5-vl-3b": {
91
+ "class": "OpenRouter",
92
+ "model_name": "qwen/qwen2.5-vl-3b-instruct",
93
+ "description": "Qwen2.5 VL 3B - OpenRouter (Fastest)",
94
  },
95
  }
96
 
97
+ POSSIBLE_API_KEYS = [
98
+ "OPENAI_API_KEY",
99
+ "ANTHROPIC_API_KEY",
100
+ "GOOGLE_API_KEY",
101
+ "HF_TOKEN",
102
+ "OPENROUTER_API_KEY",
103
+ ]
104
+
105
+
106
+ def setup_environment_variables(st_secrets=None):
107
+ for key in POSSIBLE_API_KEYS:
108
+ # Try Streamlit secrets first if provided
109
+ if st_secrets and key in st_secrets:
110
+ os.environ[key] = st_secrets[key]
111
+ elif key in os.environ:
112
+ continue
113
+
114
+
115
+ def get_model_class(class_name):
116
+ """Get actual model class from string name"""
117
+ if class_name == "ChatOpenAI":
118
+ from langchain_openai import ChatOpenAI
119
+
120
+ return ChatOpenAI
121
+ elif class_name == "ChatAnthropic":
122
+ from langchain_anthropic import ChatAnthropic
123
+
124
+ return ChatAnthropic
125
+ elif class_name == "ChatGoogleGenerativeAI":
126
+ from langchain_google_genai import ChatGoogleGenerativeAI
127
+
128
+ return ChatGoogleGenerativeAI
129
+ elif class_name == "HuggingFaceChat":
130
+ from hf_chat import HuggingFaceChat
131
+
132
+ return HuggingFaceChat
133
+ elif class_name == "OpenRouter":
134
+ from langchain_openai import ChatOpenAI
135
+ from langchain_core.utils.utils import secret_from_env
136
+
137
+ # LangChain does not support OpenRouter directly, so we need to create a custom class
138
+ # See https://github.com/langchain-ai/langchain/discussions/27964.
139
+ class ChatOpenRouter(ChatOpenAI):
140
+ openai_api_key: Optional[SecretStr] = Field(
141
+ alias="api_key",
142
+ default_factory=secret_from_env("OPENROUTER_API_KEY", default=None),
143
+ )
144
+
145
+ @property
146
+ def lc_secrets(self) -> dict[str, str]:
147
+ return {"openai_api_key": "OPENROUTER_API_KEY"}
148
+
149
+ def __init__(self, openai_api_key: Optional[str] = None, **kwargs):
150
+ openai_api_key = openai_api_key or os.environ.get("OPENROUTER_API_KEY")
151
+ super().__init__(
152
+ base_url="https://openrouter.ai/api/v1",
153
+ api_key=SecretStr(openai_api_key) if openai_api_key else None,
154
+ **kwargs,
155
+ )
156
+
157
+ return ChatOpenRouter
158
+ else:
159
+ raise ValueError(f"Unknown model class: {class_name}")
160
+
161
+
162
  # Data paths - now supports named datasets
163
  def get_data_paths(dataset_name: str = "default"):
164
  """Get data paths for a specific dataset"""
 
168
  "results": f"results/{dataset_name}/",
169
  }
170
 
171
+
172
  # Backward compatibility - default paths
173
  DATA_PATHS = get_data_paths("default")
geo_bot.py CHANGED
@@ -6,13 +6,7 @@ from typing import Tuple, List, Optional, Dict, Any, Type
6
 
7
  from PIL import Image
8
  from langchain_core.messages import HumanMessage, BaseMessage
9
- from langchain_core.language_models.chat_models import BaseChatModel
10
- from langchain_openai import ChatOpenAI
11
- from langchain_anthropic import ChatAnthropic
12
- from langchain_google_genai import ChatGoogleGenerativeAI
13
-
14
  from hf_chat import HuggingFaceChat
15
-
16
  from mapcrunch_controller import MapCrunchController
17
 
18
  # The "Golden" Prompt (v7): add more descprtions in context and task
 
6
 
7
  from PIL import Image
8
  from langchain_core.messages import HumanMessage, BaseMessage
 
 
 
 
 
9
  from hf_chat import HuggingFaceChat
 
10
  from mapcrunch_controller import MapCrunchController
11
 
12
  # The "Golden" Prompt (v7): add more descprtions in context and task
main.py CHANGED
@@ -1,16 +1,10 @@
1
  import argparse
2
  import json
3
- import random
4
- from typing import Dict, Optional, List
5
-
6
- from langchain_openai import ChatOpenAI
7
- from langchain_anthropic import ChatAnthropic
8
- from langchain_google_genai import ChatGoogleGenerativeAI
9
 
10
  from geo_bot import GeoBot
11
  from benchmark import MapGuesserBenchmark
12
  from data_collector import DataCollector
13
- from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM
14
 
15
 
16
  def agent_mode(
@@ -48,7 +42,7 @@ def agent_mode(
48
  print(f"Will run on {len(test_samples)} samples from dataset '{dataset_name}'.")
49
 
50
  config = MODELS_CONFIG.get(model_name)
51
- model_class = globals()[config["class"]]
52
  model_instance_name = config["model_name"]
53
 
54
  benchmark_helper = MapGuesserBenchmark(dataset_name=dataset_name, headless=True)
 
1
  import argparse
2
  import json
 
 
 
 
 
 
3
 
4
  from geo_bot import GeoBot
5
  from benchmark import MapGuesserBenchmark
6
  from data_collector import DataCollector
7
+ from config import MODELS_CONFIG, get_data_paths, SUCCESS_THRESHOLD_KM, get_model_class
8
 
9
 
10
  def agent_mode(
 
42
  print(f"Will run on {len(test_samples)} samples from dataset '{dataset_name}'.")
43
 
44
  config = MODELS_CONFIG.get(model_name)
45
+ model_class = get_model_class(config["class"])
46
  model_instance_name = config["model_name"]
47
 
48
  benchmark_helper = MapGuesserBenchmark(dataset_name=dataset_name, headless=True)