David de la Iglesia Castro commited on
Commit
8aa7d2f
·
unverified ·
1 Parent(s): 502c6d3

Add "vertical" tools (#20)

Browse files

* Add `tools` module.

- Include `openmeteo` and `openstreetmap`.

* Expose "vertical" tools to openai agent.

* fix(test_unit_openai): Drop outdated tool assert

docs/api.md CHANGED
@@ -10,10 +10,13 @@
10
 
11
  ::: surf_spot_finder.agents.smolagents
12
 
 
 
 
 
 
13
  ::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
14
 
15
  ::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
16
 
17
  ::: surf_spot_finder.prompts.shared.INPUT_PROMPT
18
-
19
- ::: surf_spot_finder.tracing
 
10
 
11
  ::: surf_spot_finder.agents.smolagents
12
 
13
+ ::: surf_spot_finder.tools.openmeteo
14
+ ::: surf_spot_finder.tools.openstreetmap
15
+
16
+ ::: surf_spot_finder.tracing
17
+
18
  ::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
19
 
20
  ::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
21
 
22
  ::: surf_spot_finder.prompts.shared.INPUT_PROMPT
 
 
src/surf_spot_finder/agents/openai.py CHANGED
@@ -21,6 +21,18 @@ from surf_spot_finder.prompts.openai import (
21
  SINGLE_AGENT_SYSTEM_PROMPT,
22
  MULTI_AGENT_SYSTEM_PROMPT,
23
  )
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  @function_tool
@@ -125,7 +137,15 @@ def run_openai_agent(
125
  model=model_id,
126
  instructions=instructions,
127
  name=name,
128
- tools=[search_web, visit_webpage],
 
 
 
 
 
 
 
 
129
  )
130
  result = Runner.run_sync(agent, prompt)
131
  logger.info(result.final_output)
 
21
  SINGLE_AGENT_SYSTEM_PROMPT,
22
  MULTI_AGENT_SYSTEM_PROMPT,
23
  )
24
+ from surf_spot_finder.tools.openmeteo import get_wave_forecast, get_wind_forecast
25
+ from surf_spot_finder.tools.openstreetmap import (
26
+ driving_hours_to_meters,
27
+ get_area_lat_lon,
28
+ get_surfing_places,
29
+ )
30
+
31
+ driving_hours_to_meters = function_tool(driving_hours_to_meters)
32
+ get_area_lat_lon = function_tool(get_area_lat_lon)
33
+ get_surfing_places = function_tool(get_surfing_places)
34
+ get_wave_forecast = function_tool(get_wave_forecast)
35
+ get_wind_forecast = function_tool(get_wind_forecast)
36
 
37
 
38
  @function_tool
 
137
  model=model_id,
138
  instructions=instructions,
139
  name=name,
140
+ tools=[
141
+ search_web,
142
+ visit_webpage,
143
+ get_area_lat_lon,
144
+ get_surfing_places,
145
+ get_wave_forecast,
146
+ get_wind_forecast,
147
+ driving_hours_to_meters,
148
+ ],
149
  )
150
  result = Runner.run_sync(agent, prompt)
151
  logger.info(result.final_output)
src/surf_spot_finder/prompts/openai.py CHANGED
@@ -3,6 +3,8 @@ You will be asked to perform a task.
3
 
4
  Before solving the task, plan a sequence of actions using the available tools.
5
  Then, execute the sequence of actions using the tools.
 
 
6
  """.strip()
7
 
8
  MULTI_AGENT_SYSTEM_PROMPT = """
 
3
 
4
  Before solving the task, plan a sequence of actions using the available tools.
5
  Then, execute the sequence of actions using the tools.
6
+
7
+ Prefer to use task-specific tools before relying on generic tools like web search.
8
  """.strip()
9
 
10
  MULTI_AGENT_SYSTEM_PROMPT = """
src/surf_spot_finder/tools/__init__.py ADDED
File without changes
src/surf_spot_finder/tools/openmeteo.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datetime import datetime, timedelta
3
+ import requests
4
+
5
+
6
+ def _extract_hourly_data(data: dict) -> list[dict]:
7
+ hourly_data = data["hourly"]
8
+ result = [
9
+ {k: v for k, v in zip(hourly_data.keys(), values)}
10
+ for values in zip(*hourly_data.values())
11
+ ]
12
+ return result
13
+
14
+
15
+ def _filter_by_date(
16
+ date: datetime, hourly_data: list[dict], timedelta: timedelta = timedelta(hours=1)
17
+ ):
18
+ start_date = date - timedelta
19
+ end_date = date + timedelta
20
+ return [
21
+ item
22
+ for item in hourly_data
23
+ if start_date <= datetime.fromisoformat(item["time"]) <= end_date
24
+ ]
25
+
26
+
27
+ def get_wave_forecast(lat: float, lon: float, date: str | None = None) -> list[dict]:
28
+ """Get wave forecast for given location.
29
+
30
+ Forecast will include:
31
+
32
+ - wave_direction (degrees)
33
+ - wave_height (meters)
34
+ - wave_period (seconds)
35
+ - sea_level_height_msl (meters)
36
+
37
+ Args:
38
+ lat (float): Latitude of the location.
39
+ lon (float): Longitude of the location.
40
+ date (str | None): Date to filter by in any valid ISO 8601 format.
41
+ If not provided, all data (default to 6 days forecast) will be returned.
42
+
43
+ Returns:
44
+ list[dict]: Hourly data for wave forecast.
45
+ Example output:
46
+
47
+ ```json
48
+ [
49
+ {'time': '2025-03-19T09:00', 'winddirection_10m': 140, 'windspeed_10m': 24.5}, {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1},
50
+ {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1}, {'time': '2025-03-19T11:00', 'winddirection_10m': 141, 'windspeed_10m': 29.2}
51
+ ]
52
+ ```
53
+ """
54
+ url = "https://marine-api.open-meteo.com/v1/marine"
55
+ params = {
56
+ "latitude": lat,
57
+ "longitude": lon,
58
+ "hourly": [
59
+ "wave_direction",
60
+ "wave_height",
61
+ "wave_period",
62
+ "sea_level_height_msl",
63
+ ],
64
+ }
65
+ response = requests.get(url, params=params)
66
+ response.raise_for_status()
67
+ data = json.loads(response.content.decode())
68
+ hourly_data = _extract_hourly_data(data)
69
+ if date is not None:
70
+ date = datetime.fromisoformat(date)
71
+ hourly_data = _filter_by_date(date, hourly_data)
72
+ return hourly_data
73
+
74
+
75
+ def get_wind_forecast(lat: float, lon: float, date: str | None = None) -> list[dict]:
76
+ """Get wind forecast for given location.
77
+
78
+ Forecast will include:
79
+
80
+ - wind_direction (degrees)
81
+ - wind_speed (meters per second)
82
+
83
+ Args:
84
+ lat (float): Latitude of the location.
85
+ lon (float): Longitude of the location.
86
+ date (str | None): Date to filter by in any valid ISO 8601 format.
87
+ If not provided, all data (default to 6 days forecast) will be returned.
88
+
89
+ Returns:
90
+ list[dict]: Hourly data for wind forecast.
91
+ Example output:
92
+
93
+ ```json
94
+ [
95
+ {"time": "2025-03-18T22:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.45, "sea_level_height_msl": -1.27},
96
+ {"time": "2025-03-18T23:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.35, "sea_level_height_msl": -1.35},
97
+ ]
98
+ ```
99
+ """
100
+ url = "https://api.open-meteo.com/v1/forecast"
101
+ params = {
102
+ "latitude": lat,
103
+ "longitude": lon,
104
+ "hourly": ["winddirection_10m", "windspeed_10m"],
105
+ }
106
+ response = requests.get(url, params=params)
107
+ response.raise_for_status()
108
+ data = json.loads(response.content.decode())
109
+ hourly_data = _extract_hourly_data(data)
110
+ if date is not None:
111
+ date = datetime.fromisoformat(date)
112
+ hourly_data = _filter_by_date(date, hourly_data)
113
+ return hourly_data
src/surf_spot_finder/tools/openstreetmap.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import json
3
+
4
+
5
+ def get_area_lat_lon(area_name: str) -> tuple[float, float]:
6
+ """Get the latitude and longitude of an area from Nominatim.
7
+
8
+ Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
9
+
10
+ Args:
11
+ area_name (str): The name of the area.
12
+
13
+ Returns:
14
+ dict: The area found.
15
+ """
16
+ response = requests.get(
17
+ f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
18
+ headers={"User-Agent": "Mozilla/5.0"},
19
+ )
20
+ response.raise_for_status()
21
+ area = json.loads(response.content.decode())
22
+ return area[0]["lat"], area[0]["lon"]
23
+
24
+
25
+ def driving_hours_to_meters(driving_hours: int) -> int:
26
+ """Convert driving hours to meters assuming a 70 km/h average speed.
27
+
28
+
29
+ Args:
30
+ driving_hours (int): The driving hours.
31
+
32
+ Returns:
33
+ int: The distance in meters.
34
+ """
35
+ return driving_hours * 70 * 1000
36
+
37
+
38
+ def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
39
+ """Get the latitude and longitude of the center of a bounding box.
40
+
41
+ Args:
42
+ bounds (dict): The bounding box.
43
+
44
+ ```json
45
+ {
46
+ "minlat": float,
47
+ "minlon": float,
48
+ "maxlat": float,
49
+ "maxlon": float,
50
+ }
51
+ ```
52
+
53
+ Returns:
54
+ tuple: The latitude and longitude of the center.
55
+ """
56
+ return (
57
+ (bounds["minlat"] + bounds["maxlat"]) / 2,
58
+ (bounds["minlon"] + bounds["maxlon"]) / 2,
59
+ )
60
+
61
+
62
+ def get_surfing_places(
63
+ lat: float, lon: float, radius: int
64
+ ) -> list[tuple[str, tuple[float, float]]]:
65
+ """Get surfing places around a given latitude and longitude.
66
+
67
+ Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
68
+
69
+ Args:
70
+ lat (float): The latitude.
71
+ lon (float): The longitude.
72
+ radius (int): The radius in meters.
73
+
74
+ Returns:
75
+ dict: The surfing places found.
76
+ """
77
+ overpass_url = "https://overpass-api.de/api/interpreter"
78
+ query = "[out:json];("
79
+ query += f'nwr["natural"="beach"](around:{radius},{lat},{lon});'
80
+ query += f'nwr["natural"="reef"](around:{radius},{lat},{lon});'
81
+ query += ");out body geom;"
82
+ params = {"data": query}
83
+ response = requests.get(
84
+ overpass_url, params=params, headers={"User-Agent": "Mozilla/5.0"}
85
+ )
86
+ response.raise_for_status()
87
+ elements = response.json()["elements"]
88
+ return [
89
+ (element.get("tags", {}).get("name", ""), get_lat_lon_center(element["bounds"]))
90
+ for element in elements
91
+ if "surfing" in element.get("tags", {}).get("sport", "")
92
+ ]
tests/unit/agents/test_unit_openai.py CHANGED
@@ -28,7 +28,7 @@ def test_run_openai_agent_default():
28
  model="gpt-4o",
29
  instructions=SINGLE_AGENT_SYSTEM_PROMPT,
30
  name="surf-spot-finder",
31
- tools=[search_web, visit_webpage],
32
  )
33
 
34
 
 
28
  model="gpt-4o",
29
  instructions=SINGLE_AGENT_SYSTEM_PROMPT,
30
  name="surf-spot-finder",
31
+ tools=ANY,
32
  )
33
 
34
 
tests/unit/tools/test_unit_openmeteo.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datetime import datetime, timedelta
3
+ from unittest.mock import patch, MagicMock
4
+
5
+ from surf_spot_finder.tools import openmeteo
6
+
7
+
8
+ def test_extract_hourly_data():
9
+ data = {
10
+ "hourly": {
11
+ "time": ["2023-01-01T00:00", "2023-01-01T01:00"],
12
+ "wave_height": [1.5, 1.6],
13
+ "wave_period": [10, 11],
14
+ }
15
+ }
16
+ expected = [
17
+ {"time": "2023-01-01T00:00", "wave_height": 1.5, "wave_period": 10},
18
+ {"time": "2023-01-01T01:00", "wave_height": 1.6, "wave_period": 11},
19
+ ]
20
+ assert openmeteo._extract_hourly_data(data) == expected
21
+
22
+
23
+ def test_filter_by_date():
24
+ hourly_data = [
25
+ {"time": "2023-01-01T00:00", "wave_height": 1.5},
26
+ {"time": "2023-01-01T01:00", "wave_height": 1.6},
27
+ {"time": "2023-01-01T02:00", "wave_height": 1.7},
28
+ {"time": "2023-01-01T03:00", "wave_height": 1.8},
29
+ ]
30
+ date = datetime.fromisoformat("2023-01-01T01:00")
31
+ expected = [
32
+ {"time": "2023-01-01T00:00", "wave_height": 1.5},
33
+ {"time": "2023-01-01T01:00", "wave_height": 1.6},
34
+ {"time": "2023-01-01T02:00", "wave_height": 1.7},
35
+ ]
36
+ assert openmeteo._filter_by_date(date, hourly_data) == expected
37
+
38
+ expected = [
39
+ {"time": "2023-01-01T01:00", "wave_height": 1.6},
40
+ ]
41
+ assert openmeteo._filter_by_date(date, hourly_data, timedelta(hours=0)) == expected
42
+
43
+
44
+ def test_get_wave_forecast():
45
+ with patch("requests.get") as mock_get:
46
+ mock_response = MagicMock()
47
+ mock_response.status_code = 200
48
+ mock_response.content.decode.return_value = json.dumps(
49
+ {
50
+ "hourly": {
51
+ "time": ["2023-02-02T00:00", "2023-02-02T01:00"],
52
+ "wave_direction": [270, 280],
53
+ "wave_height": [1.5, 1.6],
54
+ "wave_period": [10, 11],
55
+ "sea_level_height_msl": [0.5, 0.6],
56
+ }
57
+ }
58
+ )
59
+ mock_get.return_value = mock_response
60
+ result = openmeteo.get_wave_forecast(lat=40.0, lon=-3.0)
61
+
62
+ assert len(result) == 2
63
+ assert result[1]["time"] == "2023-02-02T01:00"
64
+ assert result[1]["wave_direction"] == 280
65
+ assert result[1]["wave_height"] == 1.6
66
+ assert result[1]["wave_period"] == 11
67
+ assert result[1]["sea_level_height_msl"] == 0.6
68
+
69
+ result_filtered = openmeteo.get_wave_forecast(
70
+ lat=40.0, lon=-3.0, date="2023-02-02T02:00"
71
+ )
72
+ assert len(result_filtered) == 1
73
+ assert result_filtered[0]["time"] == "2023-02-02T01:00"
74
+
75
+
76
+ def test_get_wind_forecast():
77
+ with patch("requests.get") as mock_get:
78
+ mock_response = MagicMock()
79
+ mock_response.status_code = 200
80
+ mock_response.content.decode.return_value = json.dumps(
81
+ {
82
+ "hourly": {
83
+ "time": ["2023-02-02T00:00", "2023-02-02T01:00"],
84
+ "winddirection_10m": [270, 280],
85
+ "windspeed_10m": [10, 11],
86
+ }
87
+ }
88
+ )
89
+ mock_get.return_value = mock_response
90
+
91
+ result = openmeteo.get_wind_forecast(lat=40.0, lon=-3.0)
92
+ assert len(result) == 2
93
+ assert result[1]["time"] == "2023-02-02T01:00"
94
+ assert result[1]["winddirection_10m"] == 280
95
+ assert result[1]["windspeed_10m"] == 11
96
+
97
+ result_filtered = openmeteo.get_wind_forecast(
98
+ lat=40.0, lon=-3.0, date="2023-02-02T02:00"
99
+ )
100
+ assert len(result_filtered) == 1
101
+ assert result_filtered[0]["time"] == "2023-02-02T01:00"
102
+ assert result[0]["windspeed_10m"] == 10
tests/unit/tools/test_unit_openstreetmap.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ from surf_spot_finder.tools import openstreetmap
5
+
6
+
7
+ def test_get_area_lat_lon():
8
+ with patch("requests.get") as mock_get:
9
+ mock_response = MagicMock()
10
+ mock_response.status_code = 200
11
+ mock_response.content.decode.return_value = json.dumps(
12
+ [{"lat": "40.0", "lon": "-3.0"}]
13
+ )
14
+ mock_get.return_value = mock_response
15
+
16
+ lat, lon = openstreetmap.get_area_lat_lon("Madrid")
17
+ assert lat == "40.0"
18
+ assert lon == "-3.0"
19
+
20
+
21
+ def test_driving_hours_to_meters():
22
+ assert openstreetmap.driving_hours_to_meters(1) == 70000
23
+
24
+
25
+ def test_get_lat_lon_center():
26
+ bounds = {"minlat": 40.0, "minlon": -3.0, "maxlat": 41.0, "maxlon": -2.0}
27
+ lat, lon = openstreetmap.get_lat_lon_center(bounds)
28
+ assert lat == 40.5
29
+ assert lon == -2.5
30
+
31
+
32
+ def test_get_surfing_places():
33
+ with patch("requests.get") as mock_get:
34
+ mock_response = MagicMock()
35
+ mock_response.status_code = 200
36
+ mock_response.json.return_value = {
37
+ "elements": [
38
+ {
39
+ "tags": {"name": "Surf Spot 1", "sport": "surfing"},
40
+ "bounds": {
41
+ "minlat": 40.0,
42
+ "minlon": -3.0,
43
+ "maxlat": 40.1,
44
+ "maxlon": -2.9,
45
+ },
46
+ },
47
+ {
48
+ "tags": {"name": "Beach 2", "sport": "swimming"},
49
+ "bounds": {
50
+ "minlat": 41.0,
51
+ "minlon": -4.0,
52
+ "maxlat": 41.1,
53
+ "maxlon": -3.9,
54
+ },
55
+ },
56
+ {
57
+ "tags": {"name": "Surf Spot 3", "sport": "surfing"},
58
+ "bounds": {
59
+ "minlat": 42.0,
60
+ "minlon": -5.0,
61
+ "maxlat": 42.1,
62
+ "maxlon": -4.9,
63
+ },
64
+ },
65
+ ]
66
+ }
67
+ mock_get.return_value = mock_response
68
+
69
+ results = openstreetmap.get_surfing_places(lat=40.5, lon=-3.5, radius=10000)
70
+ assert len(results) == 2
71
+ assert results[0][0] == "Surf Spot 1"
72
+ assert results[0][1] == (40.05, -2.95)
73
+ assert results[1][0] == "Surf Spot 3"
74
+ assert results[1][1] == (42.05, -4.95)