Spaces:
Running
Running
David de la Iglesia Castro
commited on
Add "vertical" tools (#20)
Browse files* Add `tools` module.
- Include `openmeteo` and `openstreetmap`.
* Expose "vertical" tools to openai agent.
* fix(test_unit_openai): Drop outdated tool assert
- docs/api.md +5 -2
- src/surf_spot_finder/agents/openai.py +21 -1
- src/surf_spot_finder/prompts/openai.py +2 -0
- src/surf_spot_finder/tools/__init__.py +0 -0
- src/surf_spot_finder/tools/openmeteo.py +113 -0
- src/surf_spot_finder/tools/openstreetmap.py +92 -0
- tests/unit/agents/test_unit_openai.py +1 -1
- tests/unit/tools/test_unit_openmeteo.py +102 -0
- tests/unit/tools/test_unit_openstreetmap.py +74 -0
docs/api.md
CHANGED
@@ -10,10 +10,13 @@
|
|
10 |
|
11 |
::: surf_spot_finder.agents.smolagents
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
|
14 |
|
15 |
::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
|
16 |
|
17 |
::: surf_spot_finder.prompts.shared.INPUT_PROMPT
|
18 |
-
|
19 |
-
::: surf_spot_finder.tracing
|
|
|
10 |
|
11 |
::: surf_spot_finder.agents.smolagents
|
12 |
|
13 |
+
::: surf_spot_finder.tools.openmeteo
|
14 |
+
::: surf_spot_finder.tools.openstreetmap
|
15 |
+
|
16 |
+
::: surf_spot_finder.tracing
|
17 |
+
|
18 |
::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
|
19 |
|
20 |
::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
|
21 |
|
22 |
::: surf_spot_finder.prompts.shared.INPUT_PROMPT
|
|
|
|
src/surf_spot_finder/agents/openai.py
CHANGED
@@ -21,6 +21,18 @@ from surf_spot_finder.prompts.openai import (
|
|
21 |
SINGLE_AGENT_SYSTEM_PROMPT,
|
22 |
MULTI_AGENT_SYSTEM_PROMPT,
|
23 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
@function_tool
|
@@ -125,7 +137,15 @@ def run_openai_agent(
|
|
125 |
model=model_id,
|
126 |
instructions=instructions,
|
127 |
name=name,
|
128 |
-
tools=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
)
|
130 |
result = Runner.run_sync(agent, prompt)
|
131 |
logger.info(result.final_output)
|
|
|
21 |
SINGLE_AGENT_SYSTEM_PROMPT,
|
22 |
MULTI_AGENT_SYSTEM_PROMPT,
|
23 |
)
|
24 |
+
from surf_spot_finder.tools.openmeteo import get_wave_forecast, get_wind_forecast
|
25 |
+
from surf_spot_finder.tools.openstreetmap import (
|
26 |
+
driving_hours_to_meters,
|
27 |
+
get_area_lat_lon,
|
28 |
+
get_surfing_places,
|
29 |
+
)
|
30 |
+
|
31 |
+
driving_hours_to_meters = function_tool(driving_hours_to_meters)
|
32 |
+
get_area_lat_lon = function_tool(get_area_lat_lon)
|
33 |
+
get_surfing_places = function_tool(get_surfing_places)
|
34 |
+
get_wave_forecast = function_tool(get_wave_forecast)
|
35 |
+
get_wind_forecast = function_tool(get_wind_forecast)
|
36 |
|
37 |
|
38 |
@function_tool
|
|
|
137 |
model=model_id,
|
138 |
instructions=instructions,
|
139 |
name=name,
|
140 |
+
tools=[
|
141 |
+
search_web,
|
142 |
+
visit_webpage,
|
143 |
+
get_area_lat_lon,
|
144 |
+
get_surfing_places,
|
145 |
+
get_wave_forecast,
|
146 |
+
get_wind_forecast,
|
147 |
+
driving_hours_to_meters,
|
148 |
+
],
|
149 |
)
|
150 |
result = Runner.run_sync(agent, prompt)
|
151 |
logger.info(result.final_output)
|
src/surf_spot_finder/prompts/openai.py
CHANGED
@@ -3,6 +3,8 @@ You will be asked to perform a task.
|
|
3 |
|
4 |
Before solving the task, plan a sequence of actions using the available tools.
|
5 |
Then, execute the sequence of actions using the tools.
|
|
|
|
|
6 |
""".strip()
|
7 |
|
8 |
MULTI_AGENT_SYSTEM_PROMPT = """
|
|
|
3 |
|
4 |
Before solving the task, plan a sequence of actions using the available tools.
|
5 |
Then, execute the sequence of actions using the tools.
|
6 |
+
|
7 |
+
Prefer to use task-specific tools before relying on generic tools like web search.
|
8 |
""".strip()
|
9 |
|
10 |
MULTI_AGENT_SYSTEM_PROMPT = """
|
src/surf_spot_finder/tools/__init__.py
ADDED
File without changes
|
src/surf_spot_finder/tools/openmeteo.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
import requests
|
4 |
+
|
5 |
+
|
6 |
+
def _extract_hourly_data(data: dict) -> list[dict]:
|
7 |
+
hourly_data = data["hourly"]
|
8 |
+
result = [
|
9 |
+
{k: v for k, v in zip(hourly_data.keys(), values)}
|
10 |
+
for values in zip(*hourly_data.values())
|
11 |
+
]
|
12 |
+
return result
|
13 |
+
|
14 |
+
|
15 |
+
def _filter_by_date(
|
16 |
+
date: datetime, hourly_data: list[dict], timedelta: timedelta = timedelta(hours=1)
|
17 |
+
):
|
18 |
+
start_date = date - timedelta
|
19 |
+
end_date = date + timedelta
|
20 |
+
return [
|
21 |
+
item
|
22 |
+
for item in hourly_data
|
23 |
+
if start_date <= datetime.fromisoformat(item["time"]) <= end_date
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
def get_wave_forecast(lat: float, lon: float, date: str | None = None) -> list[dict]:
|
28 |
+
"""Get wave forecast for given location.
|
29 |
+
|
30 |
+
Forecast will include:
|
31 |
+
|
32 |
+
- wave_direction (degrees)
|
33 |
+
- wave_height (meters)
|
34 |
+
- wave_period (seconds)
|
35 |
+
- sea_level_height_msl (meters)
|
36 |
+
|
37 |
+
Args:
|
38 |
+
lat (float): Latitude of the location.
|
39 |
+
lon (float): Longitude of the location.
|
40 |
+
date (str | None): Date to filter by in any valid ISO 8601 format.
|
41 |
+
If not provided, all data (default to 6 days forecast) will be returned.
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
list[dict]: Hourly data for wave forecast.
|
45 |
+
Example output:
|
46 |
+
|
47 |
+
```json
|
48 |
+
[
|
49 |
+
{'time': '2025-03-19T09:00', 'winddirection_10m': 140, 'windspeed_10m': 24.5}, {'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1},
|
50 |
+
{'time': '2025-03-19T10:00', 'winddirection_10m': 140, 'windspeed_10m': 27.1}, {'time': '2025-03-19T11:00', 'winddirection_10m': 141, 'windspeed_10m': 29.2}
|
51 |
+
]
|
52 |
+
```
|
53 |
+
"""
|
54 |
+
url = "https://marine-api.open-meteo.com/v1/marine"
|
55 |
+
params = {
|
56 |
+
"latitude": lat,
|
57 |
+
"longitude": lon,
|
58 |
+
"hourly": [
|
59 |
+
"wave_direction",
|
60 |
+
"wave_height",
|
61 |
+
"wave_period",
|
62 |
+
"sea_level_height_msl",
|
63 |
+
],
|
64 |
+
}
|
65 |
+
response = requests.get(url, params=params)
|
66 |
+
response.raise_for_status()
|
67 |
+
data = json.loads(response.content.decode())
|
68 |
+
hourly_data = _extract_hourly_data(data)
|
69 |
+
if date is not None:
|
70 |
+
date = datetime.fromisoformat(date)
|
71 |
+
hourly_data = _filter_by_date(date, hourly_data)
|
72 |
+
return hourly_data
|
73 |
+
|
74 |
+
|
75 |
+
def get_wind_forecast(lat: float, lon: float, date: str | None = None) -> list[dict]:
|
76 |
+
"""Get wind forecast for given location.
|
77 |
+
|
78 |
+
Forecast will include:
|
79 |
+
|
80 |
+
- wind_direction (degrees)
|
81 |
+
- wind_speed (meters per second)
|
82 |
+
|
83 |
+
Args:
|
84 |
+
lat (float): Latitude of the location.
|
85 |
+
lon (float): Longitude of the location.
|
86 |
+
date (str | None): Date to filter by in any valid ISO 8601 format.
|
87 |
+
If not provided, all data (default to 6 days forecast) will be returned.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
list[dict]: Hourly data for wind forecast.
|
91 |
+
Example output:
|
92 |
+
|
93 |
+
```json
|
94 |
+
[
|
95 |
+
{"time": "2025-03-18T22:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.45, "sea_level_height_msl": -1.27},
|
96 |
+
{"time": "2025-03-18T23:00", "wave_direction": 264, "wave_height": 2.24, "wave_period": 10.35, "sea_level_height_msl": -1.35},
|
97 |
+
]
|
98 |
+
```
|
99 |
+
"""
|
100 |
+
url = "https://api.open-meteo.com/v1/forecast"
|
101 |
+
params = {
|
102 |
+
"latitude": lat,
|
103 |
+
"longitude": lon,
|
104 |
+
"hourly": ["winddirection_10m", "windspeed_10m"],
|
105 |
+
}
|
106 |
+
response = requests.get(url, params=params)
|
107 |
+
response.raise_for_status()
|
108 |
+
data = json.loads(response.content.decode())
|
109 |
+
hourly_data = _extract_hourly_data(data)
|
110 |
+
if date is not None:
|
111 |
+
date = datetime.fromisoformat(date)
|
112 |
+
hourly_data = _filter_by_date(date, hourly_data)
|
113 |
+
return hourly_data
|
src/surf_spot_finder/tools/openstreetmap.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
def get_area_lat_lon(area_name: str) -> tuple[float, float]:
|
6 |
+
"""Get the latitude and longitude of an area from Nominatim.
|
7 |
+
|
8 |
+
Uses the [Nominatim API](https://nominatim.org/release-docs/develop/api/Search/).
|
9 |
+
|
10 |
+
Args:
|
11 |
+
area_name (str): The name of the area.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
dict: The area found.
|
15 |
+
"""
|
16 |
+
response = requests.get(
|
17 |
+
f"https://nominatim.openstreetmap.org/search?q={area_name}&format=json",
|
18 |
+
headers={"User-Agent": "Mozilla/5.0"},
|
19 |
+
)
|
20 |
+
response.raise_for_status()
|
21 |
+
area = json.loads(response.content.decode())
|
22 |
+
return area[0]["lat"], area[0]["lon"]
|
23 |
+
|
24 |
+
|
25 |
+
def driving_hours_to_meters(driving_hours: int) -> int:
|
26 |
+
"""Convert driving hours to meters assuming a 70 km/h average speed.
|
27 |
+
|
28 |
+
|
29 |
+
Args:
|
30 |
+
driving_hours (int): The driving hours.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
int: The distance in meters.
|
34 |
+
"""
|
35 |
+
return driving_hours * 70 * 1000
|
36 |
+
|
37 |
+
|
38 |
+
def get_lat_lon_center(bounds: dict) -> tuple[float, float]:
|
39 |
+
"""Get the latitude and longitude of the center of a bounding box.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
bounds (dict): The bounding box.
|
43 |
+
|
44 |
+
```json
|
45 |
+
{
|
46 |
+
"minlat": float,
|
47 |
+
"minlon": float,
|
48 |
+
"maxlat": float,
|
49 |
+
"maxlon": float,
|
50 |
+
}
|
51 |
+
```
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
tuple: The latitude and longitude of the center.
|
55 |
+
"""
|
56 |
+
return (
|
57 |
+
(bounds["minlat"] + bounds["maxlat"]) / 2,
|
58 |
+
(bounds["minlon"] + bounds["maxlon"]) / 2,
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def get_surfing_places(
|
63 |
+
lat: float, lon: float, radius: int
|
64 |
+
) -> list[tuple[str, tuple[float, float]]]:
|
65 |
+
"""Get surfing places around a given latitude and longitude.
|
66 |
+
|
67 |
+
Uses the [Overpass API](https://wiki.openstreetmap.org/wiki/Overpass_API).
|
68 |
+
|
69 |
+
Args:
|
70 |
+
lat (float): The latitude.
|
71 |
+
lon (float): The longitude.
|
72 |
+
radius (int): The radius in meters.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
dict: The surfing places found.
|
76 |
+
"""
|
77 |
+
overpass_url = "https://overpass-api.de/api/interpreter"
|
78 |
+
query = "[out:json];("
|
79 |
+
query += f'nwr["natural"="beach"](around:{radius},{lat},{lon});'
|
80 |
+
query += f'nwr["natural"="reef"](around:{radius},{lat},{lon});'
|
81 |
+
query += ");out body geom;"
|
82 |
+
params = {"data": query}
|
83 |
+
response = requests.get(
|
84 |
+
overpass_url, params=params, headers={"User-Agent": "Mozilla/5.0"}
|
85 |
+
)
|
86 |
+
response.raise_for_status()
|
87 |
+
elements = response.json()["elements"]
|
88 |
+
return [
|
89 |
+
(element.get("tags", {}).get("name", ""), get_lat_lon_center(element["bounds"]))
|
90 |
+
for element in elements
|
91 |
+
if "surfing" in element.get("tags", {}).get("sport", "")
|
92 |
+
]
|
tests/unit/agents/test_unit_openai.py
CHANGED
@@ -28,7 +28,7 @@ def test_run_openai_agent_default():
|
|
28 |
model="gpt-4o",
|
29 |
instructions=SINGLE_AGENT_SYSTEM_PROMPT,
|
30 |
name="surf-spot-finder",
|
31 |
-
tools=
|
32 |
)
|
33 |
|
34 |
|
|
|
28 |
model="gpt-4o",
|
29 |
instructions=SINGLE_AGENT_SYSTEM_PROMPT,
|
30 |
name="surf-spot-finder",
|
31 |
+
tools=ANY,
|
32 |
)
|
33 |
|
34 |
|
tests/unit/tools/test_unit_openmeteo.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
from unittest.mock import patch, MagicMock
|
4 |
+
|
5 |
+
from surf_spot_finder.tools import openmeteo
|
6 |
+
|
7 |
+
|
8 |
+
def test_extract_hourly_data():
|
9 |
+
data = {
|
10 |
+
"hourly": {
|
11 |
+
"time": ["2023-01-01T00:00", "2023-01-01T01:00"],
|
12 |
+
"wave_height": [1.5, 1.6],
|
13 |
+
"wave_period": [10, 11],
|
14 |
+
}
|
15 |
+
}
|
16 |
+
expected = [
|
17 |
+
{"time": "2023-01-01T00:00", "wave_height": 1.5, "wave_period": 10},
|
18 |
+
{"time": "2023-01-01T01:00", "wave_height": 1.6, "wave_period": 11},
|
19 |
+
]
|
20 |
+
assert openmeteo._extract_hourly_data(data) == expected
|
21 |
+
|
22 |
+
|
23 |
+
def test_filter_by_date():
|
24 |
+
hourly_data = [
|
25 |
+
{"time": "2023-01-01T00:00", "wave_height": 1.5},
|
26 |
+
{"time": "2023-01-01T01:00", "wave_height": 1.6},
|
27 |
+
{"time": "2023-01-01T02:00", "wave_height": 1.7},
|
28 |
+
{"time": "2023-01-01T03:00", "wave_height": 1.8},
|
29 |
+
]
|
30 |
+
date = datetime.fromisoformat("2023-01-01T01:00")
|
31 |
+
expected = [
|
32 |
+
{"time": "2023-01-01T00:00", "wave_height": 1.5},
|
33 |
+
{"time": "2023-01-01T01:00", "wave_height": 1.6},
|
34 |
+
{"time": "2023-01-01T02:00", "wave_height": 1.7},
|
35 |
+
]
|
36 |
+
assert openmeteo._filter_by_date(date, hourly_data) == expected
|
37 |
+
|
38 |
+
expected = [
|
39 |
+
{"time": "2023-01-01T01:00", "wave_height": 1.6},
|
40 |
+
]
|
41 |
+
assert openmeteo._filter_by_date(date, hourly_data, timedelta(hours=0)) == expected
|
42 |
+
|
43 |
+
|
44 |
+
def test_get_wave_forecast():
|
45 |
+
with patch("requests.get") as mock_get:
|
46 |
+
mock_response = MagicMock()
|
47 |
+
mock_response.status_code = 200
|
48 |
+
mock_response.content.decode.return_value = json.dumps(
|
49 |
+
{
|
50 |
+
"hourly": {
|
51 |
+
"time": ["2023-02-02T00:00", "2023-02-02T01:00"],
|
52 |
+
"wave_direction": [270, 280],
|
53 |
+
"wave_height": [1.5, 1.6],
|
54 |
+
"wave_period": [10, 11],
|
55 |
+
"sea_level_height_msl": [0.5, 0.6],
|
56 |
+
}
|
57 |
+
}
|
58 |
+
)
|
59 |
+
mock_get.return_value = mock_response
|
60 |
+
result = openmeteo.get_wave_forecast(lat=40.0, lon=-3.0)
|
61 |
+
|
62 |
+
assert len(result) == 2
|
63 |
+
assert result[1]["time"] == "2023-02-02T01:00"
|
64 |
+
assert result[1]["wave_direction"] == 280
|
65 |
+
assert result[1]["wave_height"] == 1.6
|
66 |
+
assert result[1]["wave_period"] == 11
|
67 |
+
assert result[1]["sea_level_height_msl"] == 0.6
|
68 |
+
|
69 |
+
result_filtered = openmeteo.get_wave_forecast(
|
70 |
+
lat=40.0, lon=-3.0, date="2023-02-02T02:00"
|
71 |
+
)
|
72 |
+
assert len(result_filtered) == 1
|
73 |
+
assert result_filtered[0]["time"] == "2023-02-02T01:00"
|
74 |
+
|
75 |
+
|
76 |
+
def test_get_wind_forecast():
|
77 |
+
with patch("requests.get") as mock_get:
|
78 |
+
mock_response = MagicMock()
|
79 |
+
mock_response.status_code = 200
|
80 |
+
mock_response.content.decode.return_value = json.dumps(
|
81 |
+
{
|
82 |
+
"hourly": {
|
83 |
+
"time": ["2023-02-02T00:00", "2023-02-02T01:00"],
|
84 |
+
"winddirection_10m": [270, 280],
|
85 |
+
"windspeed_10m": [10, 11],
|
86 |
+
}
|
87 |
+
}
|
88 |
+
)
|
89 |
+
mock_get.return_value = mock_response
|
90 |
+
|
91 |
+
result = openmeteo.get_wind_forecast(lat=40.0, lon=-3.0)
|
92 |
+
assert len(result) == 2
|
93 |
+
assert result[1]["time"] == "2023-02-02T01:00"
|
94 |
+
assert result[1]["winddirection_10m"] == 280
|
95 |
+
assert result[1]["windspeed_10m"] == 11
|
96 |
+
|
97 |
+
result_filtered = openmeteo.get_wind_forecast(
|
98 |
+
lat=40.0, lon=-3.0, date="2023-02-02T02:00"
|
99 |
+
)
|
100 |
+
assert len(result_filtered) == 1
|
101 |
+
assert result_filtered[0]["time"] == "2023-02-02T01:00"
|
102 |
+
assert result[0]["windspeed_10m"] == 10
|
tests/unit/tools/test_unit_openstreetmap.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from unittest.mock import MagicMock, patch
|
3 |
+
|
4 |
+
from surf_spot_finder.tools import openstreetmap
|
5 |
+
|
6 |
+
|
7 |
+
def test_get_area_lat_lon():
|
8 |
+
with patch("requests.get") as mock_get:
|
9 |
+
mock_response = MagicMock()
|
10 |
+
mock_response.status_code = 200
|
11 |
+
mock_response.content.decode.return_value = json.dumps(
|
12 |
+
[{"lat": "40.0", "lon": "-3.0"}]
|
13 |
+
)
|
14 |
+
mock_get.return_value = mock_response
|
15 |
+
|
16 |
+
lat, lon = openstreetmap.get_area_lat_lon("Madrid")
|
17 |
+
assert lat == "40.0"
|
18 |
+
assert lon == "-3.0"
|
19 |
+
|
20 |
+
|
21 |
+
def test_driving_hours_to_meters():
|
22 |
+
assert openstreetmap.driving_hours_to_meters(1) == 70000
|
23 |
+
|
24 |
+
|
25 |
+
def test_get_lat_lon_center():
|
26 |
+
bounds = {"minlat": 40.0, "minlon": -3.0, "maxlat": 41.0, "maxlon": -2.0}
|
27 |
+
lat, lon = openstreetmap.get_lat_lon_center(bounds)
|
28 |
+
assert lat == 40.5
|
29 |
+
assert lon == -2.5
|
30 |
+
|
31 |
+
|
32 |
+
def test_get_surfing_places():
|
33 |
+
with patch("requests.get") as mock_get:
|
34 |
+
mock_response = MagicMock()
|
35 |
+
mock_response.status_code = 200
|
36 |
+
mock_response.json.return_value = {
|
37 |
+
"elements": [
|
38 |
+
{
|
39 |
+
"tags": {"name": "Surf Spot 1", "sport": "surfing"},
|
40 |
+
"bounds": {
|
41 |
+
"minlat": 40.0,
|
42 |
+
"minlon": -3.0,
|
43 |
+
"maxlat": 40.1,
|
44 |
+
"maxlon": -2.9,
|
45 |
+
},
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"tags": {"name": "Beach 2", "sport": "swimming"},
|
49 |
+
"bounds": {
|
50 |
+
"minlat": 41.0,
|
51 |
+
"minlon": -4.0,
|
52 |
+
"maxlat": 41.1,
|
53 |
+
"maxlon": -3.9,
|
54 |
+
},
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"tags": {"name": "Surf Spot 3", "sport": "surfing"},
|
58 |
+
"bounds": {
|
59 |
+
"minlat": 42.0,
|
60 |
+
"minlon": -5.0,
|
61 |
+
"maxlat": 42.1,
|
62 |
+
"maxlon": -4.9,
|
63 |
+
},
|
64 |
+
},
|
65 |
+
]
|
66 |
+
}
|
67 |
+
mock_get.return_value = mock_response
|
68 |
+
|
69 |
+
results = openstreetmap.get_surfing_places(lat=40.5, lon=-3.5, radius=10000)
|
70 |
+
assert len(results) == 2
|
71 |
+
assert results[0][0] == "Surf Spot 1"
|
72 |
+
assert results[0][1] == (40.05, -2.95)
|
73 |
+
assert results[1][0] == "Surf Spot 3"
|
74 |
+
assert results[1][1] == (42.05, -4.95)
|