Spaces:
Running
Running
David de la Iglesia Castro
committed on
enh(cli): Add `from_config` option. (#17)
Browse files
* Add run_openai_multi_agent.
* config: Simplify DEFAULT_PROMPT.
* fix(test_unit_openai): Import after patch
* Update instructions. Add communication agent
* Iterate on instructions
* enh(tracing): Handle openai_multi_agent. Add `test_setup_tracing`.
* More instructions tuning
* Add test_run_openai_multiagent
* More tweaks
* enh(cli): Add `from_config` option.
- Add `examples` dir with input configs.
- Move prompts to `surf_spot_finder.prompts`.
* fix api.md
* fix(test_unit_smolagents): Don't rely on patching `sys`.
- docs/api.md +8 -2
- examples/openai_multi_agent.yaml +6 -0
- examples/openai_single_agent.yaml +6 -0
- examples/smolagents_single_agent.yaml +7 -0
- src/surf_spot_finder/agents/openai.py +10 -19
- src/surf_spot_finder/agents/smolagents.py +11 -12
- src/surf_spot_finder/cli.py +51 -19
- src/surf_spot_finder/config.py +4 -9
- src/surf_spot_finder/prompts/__init__.py +0 -0
- src/surf_spot_finder/prompts/openai.py +20 -0
- src/surf_spot_finder/prompts/shared.py +5 -0
- src/surf_spot_finder/{agents/prompts → prompts}/smolagents.py +0 -0
- tests/unit/agents/test_unit_openai.py +6 -3
- tests/unit/agents/test_unit_smolagents.py +36 -55
docs/api.md
CHANGED
@@ -1,13 +1,19 @@
|
|
1 |
# API Reference
|
2 |
|
|
|
|
|
3 |
::: surf_spot_finder.config.Config
|
4 |
|
5 |
::: surf_spot_finder.agents.RUNNERS
|
6 |
|
7 |
::: surf_spot_finder.agents.openai
|
8 |
|
9 |
-
::: surf_spot_finder.agents.openai.DEFAULT_MULTIAGENT_INSTRUCTIONS
|
10 |
-
|
11 |
::: surf_spot_finder.agents.smolagents
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
::: surf_spot_finder.tracing
|
|
|
1 |
# API Reference
|
2 |
|
3 |
+
::: surf_spot_finder.cli
|
4 |
+
|
5 |
::: surf_spot_finder.config.Config
|
6 |
|
7 |
::: surf_spot_finder.agents.RUNNERS
|
8 |
|
9 |
::: surf_spot_finder.agents.openai
|
10 |
|
|
|
|
|
11 |
::: surf_spot_finder.agents.smolagents
|
12 |
|
13 |
+
::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
|
14 |
+
|
15 |
+
::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
|
16 |
+
|
17 |
+
::: surf_spot_finder.prompts.shared.INPUT_PROMPT
|
18 |
+
|
19 |
::: surf_spot_finder.tracing
|
examples/openai_multi_agent.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: o3-mini
|
5 |
+
agent_type: openai_multi_agent
|
6 |
+
# input_prompt_template:
|
examples/openai_single_agent.yaml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: o3-mini
|
5 |
+
agent_type: openai
|
6 |
+
# input_prompt_template:
|
examples/smolagents_single_agent.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
location: Pontevedra
|
2 |
+
date: 2025-03-22 12:00
|
3 |
+
max_driving_hours: 2
|
4 |
+
model_id: openai/gpt-3.5-turbo
|
5 |
+
api_key_var: OPENAI_API_KEY
|
6 |
+
agent_type: smolagents
|
7 |
+
# input_prompt_template:
|
src/surf_spot_finder/agents/openai.py
CHANGED
@@ -17,6 +17,12 @@ from smolagents import (
|
|
17 |
)
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
@function_tool
|
21 |
def search_web(query: str) -> str:
|
22 |
"""Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
|
@@ -69,7 +75,7 @@ def run_openai_agent(
|
|
69 |
model_id: str,
|
70 |
prompt: str,
|
71 |
name: str = "surf-spot-finder",
|
72 |
-
instructions: Optional[str] =
|
73 |
api_key_var: Optional[str] = None,
|
74 |
base_url: Optional[str] = None,
|
75 |
) -> RunResult:
|
@@ -87,7 +93,7 @@ def run_openai_agent(
|
|
87 |
prompt (str): The prompt to be given to the agent.
|
88 |
name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
|
89 |
instructions (Optional[str], optional): Initial instructions to give the agent.
|
90 |
-
Defaults to
|
91 |
api_key_var (Optional[str], optional): The name of the environment variable
|
92 |
containing the OpenAI API key. If provided, along with `base_url`, an
|
93 |
external OpenAI client will be used. Defaults to None.
|
@@ -126,27 +132,12 @@ def run_openai_agent(
|
|
126 |
return result
|
127 |
|
128 |
|
129 |
-
DEFAULT_MULTIAGENT_INSTRUCTIONS = """
|
130 |
-
You will be asked to perform a task.
|
131 |
-
|
132 |
-
Always follow this steps:
|
133 |
-
|
134 |
-
First, before solving the task, look at the available agent/tools and plan a sequence of actions using the available tools.
|
135 |
-
Second, show the plan of actions and ask for user verification. If the user does not verify the plan, come up with a better plan.
|
136 |
-
Third, execute the plan using the available tools, until you get a final answer.
|
137 |
-
|
138 |
-
Once you get a final answer, show it and ask for user verification. If the user does not verify the answer, come up with a better answer.
|
139 |
-
|
140 |
-
Finally, use the available handoff tool (`transfer_to_<agent_name>`) to communicate it to the user.
|
141 |
-
"""
|
142 |
-
|
143 |
-
|
144 |
@logger.catch(reraise=True)
|
145 |
def run_openai_multi_agent(
|
146 |
model_id: str,
|
147 |
prompt: str,
|
148 |
name: str = "surf-spot-finder",
|
149 |
-
instructions: Optional[str] =
|
150 |
) -> RunResult:
|
151 |
"""Runs multiple OpenAI agents orchestrated by a main agent.
|
152 |
|
@@ -162,7 +153,7 @@ def run_openai_multi_agent(
|
|
162 |
prompt (str): The prompt to be given to the agent.
|
163 |
name (str, optional): The name of the main agent. Defaults to "surf-spot-finder".
|
164 |
instructions (Optional[str], optional): Initial instructions to give the agent.
|
165 |
-
Defaults to [
|
166 |
|
167 |
Returns:
|
168 |
RunResult: A RunResult object containing the output of the agent run.
|
|
|
17 |
)
|
18 |
|
19 |
|
20 |
+
from surf_spot_finder.prompts.openai import (
|
21 |
+
SINGLE_AGENT_SYSTEM_PROMPT,
|
22 |
+
MULTI_AGENT_SYSTEM_PROMPT,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
@function_tool
|
27 |
def search_web(query: str) -> str:
|
28 |
"""Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
|
|
|
75 |
model_id: str,
|
76 |
prompt: str,
|
77 |
name: str = "surf-spot-finder",
|
78 |
+
instructions: Optional[str] = SINGLE_AGENT_SYSTEM_PROMPT,
|
79 |
api_key_var: Optional[str] = None,
|
80 |
base_url: Optional[str] = None,
|
81 |
) -> RunResult:
|
|
|
93 |
prompt (str): The prompt to be given to the agent.
|
94 |
name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
|
95 |
instructions (Optional[str], optional): Initial instructions to give the agent.
|
96 |
+
Defaults to [SINGLE_AGENT_SYSTEM_PROMPT][surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT].
|
97 |
api_key_var (Optional[str], optional): The name of the environment variable
|
98 |
containing the OpenAI API key. If provided, along with `base_url`, an
|
99 |
external OpenAI client will be used. Defaults to None.
|
|
|
132 |
return result
|
133 |
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
@logger.catch(reraise=True)
|
136 |
def run_openai_multi_agent(
|
137 |
model_id: str,
|
138 |
prompt: str,
|
139 |
name: str = "surf-spot-finder",
|
140 |
+
instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
|
141 |
) -> RunResult:
|
142 |
"""Runs multiple OpenAI agents orchestrated by a main agent.
|
143 |
|
|
|
153 |
prompt (str): The prompt to be given to the agent.
|
154 |
name (str, optional): The name of the main agent. Defaults to "surf-spot-finder".
|
155 |
instructions (Optional[str], optional): Initial instructions to give the agent.
|
156 |
+
Defaults to [MULTI_AGENT_SYSTEM_PROMPT][surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT].
|
157 |
|
158 |
Returns:
|
159 |
RunResult: A RunResult object containing the output of the agent run.
|
src/surf_spot_finder/agents/smolagents.py
CHANGED
@@ -1,10 +1,16 @@
|
|
1 |
import os
|
2 |
-
from typing import Optional
|
3 |
|
4 |
from loguru import logger
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
@logger.catch(reraise=True)
|
@@ -13,7 +19,7 @@ def run_smolagent(
|
|
13 |
prompt: str,
|
14 |
api_key_var: Optional[str] = None,
|
15 |
api_base: Optional[str] = None,
|
16 |
-
) ->
|
17 |
"""
|
18 |
Create and configure a Smolagents CodeAgent with the specified model.
|
19 |
|
@@ -29,17 +35,10 @@ def run_smolagent(
|
|
29 |
CodeAgent: Configured agent ready to process requests
|
30 |
|
31 |
Example:
|
|
|
32 |
>>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
|
33 |
>>> agent.run("Find surf spots near San Diego")
|
34 |
"""
|
35 |
-
from smolagents import ( # pylint: disable=import-outside-toplevel
|
36 |
-
CodeAgent,
|
37 |
-
DuckDuckGoSearchTool,
|
38 |
-
LiteLLMModel,
|
39 |
-
ToolCollection,
|
40 |
-
)
|
41 |
-
from mcp import StdioServerParameters
|
42 |
-
from surf_spot_finder.agents.prompts.smolagents import SYSTEM_PROMPT
|
43 |
|
44 |
model = LiteLLMModel(
|
45 |
model_id=model_id,
|
|
|
1 |
import os
|
2 |
+
from typing import Optional
|
3 |
|
4 |
from loguru import logger
|
5 |
|
6 |
+
from smolagents import (
|
7 |
+
CodeAgent,
|
8 |
+
DuckDuckGoSearchTool,
|
9 |
+
LiteLLMModel,
|
10 |
+
ToolCollection,
|
11 |
+
)
|
12 |
+
from mcp import StdioServerParameters
|
13 |
+
from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
|
14 |
|
15 |
|
16 |
@logger.catch(reraise=True)
|
|
|
19 |
prompt: str,
|
20 |
api_key_var: Optional[str] = None,
|
21 |
api_base: Optional[str] = None,
|
22 |
+
) -> CodeAgent:
|
23 |
"""
|
24 |
Create and configure a Smolagents CodeAgent with the specified model.
|
25 |
|
|
|
35 |
CodeAgent: Configured agent ready to process requests
|
36 |
|
37 |
Example:
|
38 |
+
|
39 |
>>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
|
40 |
>>> agent.run("Find surf spots near San Diego")
|
41 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
model = LiteLLMModel(
|
44 |
model_id=model_id,
|
src/surf_spot_finder/cli.py
CHANGED
@@ -1,40 +1,72 @@
|
|
|
|
1 |
from typing import Optional
|
2 |
|
|
|
3 |
from fire import Fire
|
4 |
from loguru import logger
|
5 |
|
6 |
from surf_spot_finder.config import (
|
7 |
Config,
|
8 |
-
DEFAULT_PROMPT,
|
9 |
)
|
10 |
from surf_spot_finder.agents import RUNNERS
|
|
|
11 |
from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
|
12 |
|
13 |
|
14 |
@logger.catch(reraise=True)
|
15 |
def find_surf_spot(
|
16 |
-
location: str,
|
17 |
-
date: str,
|
18 |
-
max_driving_hours: int,
|
19 |
-
model_id: str,
|
20 |
agent_type: str = "smolagents",
|
21 |
api_key_var: Optional[str] = None,
|
22 |
-
|
23 |
json_tracer: bool = True,
|
24 |
api_base: Optional[str] = None,
|
|
|
25 |
):
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
logger.info("Setting up tracing")
|
40 |
tracer_provider, _ = get_tracer_provider(
|
@@ -45,7 +77,7 @@ def find_surf_spot(
|
|
45 |
logger.info(f"Running {config.agent_type} agent")
|
46 |
RUNNERS[config.agent_type](
|
47 |
model_id=config.model_id,
|
48 |
-
prompt=config.
|
49 |
LOCATION=config.location,
|
50 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
51 |
DATE=config.date,
|
|
|
1 |
+
from pathlib import Path
|
2 |
from typing import Optional
|
3 |
|
4 |
+
import yaml
|
5 |
from fire import Fire
|
6 |
from loguru import logger
|
7 |
|
8 |
from surf_spot_finder.config import (
|
9 |
Config,
|
|
|
10 |
)
|
11 |
from surf_spot_finder.agents import RUNNERS
|
12 |
+
from surf_spot_finder.prompts.shared import INPUT_PROMPT
|
13 |
from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
|
14 |
|
15 |
|
16 |
@logger.catch(reraise=True)
|
17 |
def find_surf_spot(
|
18 |
+
location: Optional[str] = None,
|
19 |
+
date: Optional[str] = None,
|
20 |
+
max_driving_hours: Optional[int] = None,
|
21 |
+
model_id: Optional[str] = None,
|
22 |
agent_type: str = "smolagents",
|
23 |
api_key_var: Optional[str] = None,
|
24 |
+
input_prompt_template: str = INPUT_PROMPT,
|
25 |
json_tracer: bool = True,
|
26 |
api_base: Optional[str] = None,
|
27 |
+
from_config: Optional[str] = None,
|
28 |
):
|
29 |
+
"""Find the best surf spot based on the given criteria.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
location (str): The location to search around.
|
33 |
+
Required if `from_config` is not provided.
|
34 |
+
date (str): The date to search for.
|
35 |
+
Required if `from_config` is not provided.
|
36 |
+
max_driving_hours (int): The maximum driving hours from the location.
|
37 |
+
Required if `from_config` is not provided.
|
38 |
+
model_id (str): The ID of the model to use.
|
39 |
+
Required if `from_config` is not provided.
|
40 |
+
|
41 |
+
If using `agent_type=smolagents`, use LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet').
|
42 |
+
If using `agent_type={openai,openai_multi_agent}`, use OpenAI syntax (e.g., 'o1').
|
43 |
+
agent_type (str, optional): The type of agent to use.
|
44 |
+
Must be one of the supported types in [RUNNERS][surf_spot_finder.agents.RUNNERS].
|
45 |
+
api_key_var (Optional[str], optional): The name of the environment variable containing the API key.
|
46 |
+
input_prompt_template (str, optional): The template for the input_prompt.
|
47 |
+
|
48 |
+
Must contain the following placeholders: `{LOCATION}`, `{MAX_DRIVING_HOURS}`, and `{DATE}`.
|
49 |
+
json_tracer (bool, optional): Whether to use the custom JSON file exporter.
|
50 |
+
api_base (Optional[str], optional): The base URL for the API.
|
51 |
+
from_config (Optional[str], optional): Path to a YAML config file.
|
52 |
+
|
53 |
+
If provided, all other arguments will be ignored.
|
54 |
+
"""
|
55 |
+
if from_config:
|
56 |
+
logger.info(f"Loading {from_config}")
|
57 |
+
config = Config.model_validate(yaml.safe_load(Path(from_config).read_text()))
|
58 |
+
else:
|
59 |
+
config = Config(
|
60 |
+
location=location,
|
61 |
+
date=date,
|
62 |
+
max_driving_hours=max_driving_hours,
|
63 |
+
model_id=model_id,
|
64 |
+
agent_type=agent_type,
|
65 |
+
api_key_var=api_key_var,
|
66 |
+
prompt=input_prompt_template,
|
67 |
+
json_tracer=json_tracer,
|
68 |
+
api_base=api_base,
|
69 |
+
)
|
70 |
|
71 |
logger.info("Setting up tracing")
|
72 |
tracer_provider, _ = get_tracer_provider(
|
|
|
77 |
logger.info(f"Running {config.agent_type} agent")
|
78 |
RUNNERS[config.agent_type](
|
79 |
model_id=config.model_id,
|
80 |
+
prompt=config.input_prompt_template.format(
|
81 |
LOCATION=config.location,
|
82 |
MAX_DRIVING_HOURS=config.max_driving_hours,
|
83 |
DATE=config.date,
|
src/surf_spot_finder/config.py
CHANGED
@@ -1,14 +1,7 @@
|
|
1 |
from typing import Annotated, Optional
|
2 |
from pydantic import AfterValidator, BaseModel, FutureDatetime, PositiveInt
|
3 |
-
from datetime import datetime
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
DEFAULT_PROMPT = (
|
8 |
-
"What will be the best surf spot around {LOCATION}"
|
9 |
-
", in a {MAX_DRIVING_HOURS} hour driving radius"
|
10 |
-
", at {DATE}?"
|
11 |
-
)
|
12 |
|
13 |
|
14 |
def validate_prompt(value) -> str:
|
@@ -26,7 +19,9 @@ def validate_agent_type(value) -> str:
|
|
26 |
|
27 |
|
28 |
class Config(BaseModel):
|
29 |
-
|
|
|
|
|
30 |
location: str
|
31 |
max_driving_hours: PositiveInt
|
32 |
date: FutureDatetime
|
|
|
1 |
from typing import Annotated, Optional
|
2 |
from pydantic import AfterValidator, BaseModel, FutureDatetime, PositiveInt
|
|
|
3 |
|
4 |
+
from surf_spot_finder.prompts.shared import INPUT_PROMPT
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def validate_prompt(value) -> str:
|
|
|
19 |
|
20 |
|
21 |
class Config(BaseModel):
|
22 |
+
input_prompt_template: Annotated[str, AfterValidator(validate_prompt)] = (
|
23 |
+
INPUT_PROMPT
|
24 |
+
)
|
25 |
location: str
|
26 |
max_driving_hours: PositiveInt
|
27 |
date: FutureDatetime
|
src/surf_spot_finder/prompts/__init__.py
ADDED
File without changes
|
src/surf_spot_finder/prompts/openai.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SINGLE_AGENT_SYSTEM_PROMPT = """
|
2 |
+
You will be asked to perform a task.
|
3 |
+
|
4 |
+
Before solving the task, plan a sequence of actions using the available tools.
|
5 |
+
Then, execute the sequence of actions using the tools.
|
6 |
+
""".strip()
|
7 |
+
|
8 |
+
MULTI_AGENT_SYSTEM_PROMPT = """
|
9 |
+
You will be asked to perform a task.
|
10 |
+
|
11 |
+
Always follow this steps:
|
12 |
+
|
13 |
+
First, before solving the task, plan a sequence of actions using the available tools.
|
14 |
+
Second, show the plan of actions and ask for user verification. If the user does not verify the plan, come up with a better plan.
|
15 |
+
Third, execute the plan using the available tools, until you get a final answer.
|
16 |
+
|
17 |
+
Once you get a final answer, show it and ask for user verification. If the user does not verify the answer, come up with a better answer.
|
18 |
+
|
19 |
+
Finally, use the available handoff tool (`transfer_to_<agent_name>`) to communicate it to the user.
|
20 |
+
""".strip()
|
src/surf_spot_finder/prompts/shared.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
INPUT_PROMPT = """
|
2 |
+
According to the forecast, what will be the best spot to surf around {LOCATION},
|
3 |
+
in a {MAX_DRIVING_HOURS} hour driving radius,
|
4 |
+
at {DATE}?"
|
5 |
+
""".strip()
|
src/surf_spot_finder/{agents/prompts → prompts}/smolagents.py
RENAMED
File without changes
|
tests/unit/agents/test_unit_openai.py
CHANGED
@@ -9,7 +9,10 @@ from surf_spot_finder.agents.openai import (
|
|
9 |
search_web,
|
10 |
user_verification,
|
11 |
visit_webpage,
|
12 |
-
|
|
|
|
|
|
|
13 |
)
|
14 |
|
15 |
|
@@ -23,7 +26,7 @@ def test_run_openai_agent_default():
|
|
23 |
run_openai_agent("gpt-4o", "Test prompt")
|
24 |
mock_agent.assert_called_once_with(
|
25 |
model="gpt-4o",
|
26 |
-
instructions=
|
27 |
name="surf-spot-finder",
|
28 |
tools=[search_web, visit_webpage],
|
29 |
)
|
@@ -91,7 +94,7 @@ def test_run_openai_multiagent():
|
|
91 |
|
92 |
mock_agent.assert_any_call(
|
93 |
model="gpt-4o",
|
94 |
-
instructions=
|
95 |
name="surf-spot-finder",
|
96 |
# TODO: add more elaborated checks
|
97 |
handoffs=ANY,
|
|
|
9 |
search_web,
|
10 |
user_verification,
|
11 |
visit_webpage,
|
12 |
+
)
|
13 |
+
from surf_spot_finder.prompts.openai import (
|
14 |
+
SINGLE_AGENT_SYSTEM_PROMPT,
|
15 |
+
MULTI_AGENT_SYSTEM_PROMPT,
|
16 |
)
|
17 |
|
18 |
|
|
|
26 |
run_openai_agent("gpt-4o", "Test prompt")
|
27 |
mock_agent.assert_called_once_with(
|
28 |
model="gpt-4o",
|
29 |
+
instructions=SINGLE_AGENT_SYSTEM_PROMPT,
|
30 |
name="surf-spot-finder",
|
31 |
tools=[search_web, visit_webpage],
|
32 |
)
|
|
|
94 |
|
95 |
mock_agent.assert_any_call(
|
96 |
model="gpt-4o",
|
97 |
+
instructions=MULTI_AGENT_SYSTEM_PROMPT,
|
98 |
name="surf-spot-finder",
|
99 |
# TODO: add more elaborated checks
|
100 |
handoffs=ANY,
|
tests/unit/agents/test_unit_smolagents.py
CHANGED
@@ -1,103 +1,84 @@
|
|
1 |
import os
|
2 |
import pytest
|
3 |
from unittest.mock import patch, MagicMock
|
|
|
4 |
|
5 |
from surf_spot_finder.agents.smolagents import run_smolagent
|
6 |
|
7 |
|
8 |
@pytest.fixture
|
9 |
-
def
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
mock_litellm_model = MagicMock()
|
14 |
mock_tool_collection = MagicMock()
|
15 |
|
16 |
-
# Configure the mock tool collection to work as a context manager
|
17 |
mock_tool_collection.from_mcp.return_value.__enter__.return_value = (
|
18 |
mock_tool_collection
|
19 |
)
|
20 |
mock_tool_collection.from_mcp.return_value.__exit__.return_value = None
|
21 |
mock_tool_collection.tools = ["mock_tool"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
{
|
26 |
-
"smolagents": MagicMock(
|
27 |
-
CodeAgent=mock_code_agent,
|
28 |
-
DuckDuckGoSearchTool=mock_ddg_tool,
|
29 |
-
LiteLLMModel=mock_litellm_model,
|
30 |
-
ToolCollection=mock_tool_collection,
|
31 |
-
),
|
32 |
-
"mcp": MagicMock(
|
33 |
-
StdioServerParameters=MagicMock(),
|
34 |
-
),
|
35 |
-
},
|
36 |
-
):
|
37 |
-
yield {
|
38 |
-
"CodeAgent": mock_code_agent,
|
39 |
-
"DuckDuckGoSearchTool": mock_ddg_tool,
|
40 |
-
"LiteLLMModel": mock_litellm_model,
|
41 |
-
"ToolCollection": mock_tool_collection,
|
42 |
-
}
|
43 |
-
|
44 |
-
|
45 |
-
@pytest.mark.usefixtures("mock_smolagents_imports")
|
46 |
-
def test_run_smolagent_with_api_key_var():
|
47 |
-
"""Test smolagent creation with an API key from environment variable."""
|
48 |
-
# The patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"})
|
49 |
-
# is a testing construct that temporarily modifies the environment variables
|
50 |
-
# for the duration of the test.
|
51 |
-
# some tests use TEST_API_KEY while others don't
|
52 |
-
with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
|
53 |
-
from smolagents import CodeAgent, LiteLLMModel
|
54 |
|
|
|
55 |
run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
|
56 |
|
57 |
-
|
58 |
-
model_call_kwargs =
|
59 |
assert model_call_kwargs["model_id"] == "openai/gpt-4"
|
60 |
assert model_call_kwargs["api_key"] == "test-key-12345"
|
61 |
assert model_call_kwargs["api_base"] is None
|
62 |
|
63 |
-
|
64 |
-
|
65 |
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
"""Test smolagent creation with a custom API base."""
|
70 |
-
with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
|
71 |
-
from smolagents import LiteLLMModel
|
72 |
|
73 |
-
|
74 |
run_smolagent(
|
75 |
"anthropic/claude-3-sonnet",
|
76 |
"Test prompt",
|
77 |
api_key_var="TEST_API_KEY",
|
78 |
api_base="https://custom-api.example.com",
|
79 |
)
|
80 |
-
last_call =
|
81 |
|
82 |
assert last_call[1]["model_id"] == "anthropic/claude-3-sonnet"
|
83 |
assert last_call[1]["api_key"] == "test-key-12345"
|
84 |
assert last_call[1]["api_base"] == "https://custom-api.example.com"
|
85 |
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
"""You should be able to run the smolagent without an API key."""
|
90 |
-
from smolagents import LiteLLMModel
|
91 |
|
92 |
-
|
|
|
93 |
|
94 |
-
last_call =
|
95 |
assert last_call[1]["model_id"] == "ollama_chat/deepseek-r1"
|
96 |
assert last_call[1]["api_key"] is None
|
97 |
|
98 |
|
99 |
def test_run_smolagent_environment_error():
|
100 |
-
"""Test that passing a bad api_key_var throws an error"""
|
101 |
with patch.dict(os.environ, {}, clear=True):
|
102 |
with pytest.raises(KeyError, match="MISSING_KEY"):
|
103 |
run_smolagent("test-model", "Test prompt", api_key_var="MISSING_KEY")
|
|
|
1 |
import os
|
2 |
import pytest
|
3 |
from unittest.mock import patch, MagicMock
|
4 |
+
import contextlib
|
5 |
|
6 |
from surf_spot_finder.agents.smolagents import run_smolagent
|
7 |
|
8 |
|
9 |
@pytest.fixture
|
10 |
+
def common_patches():
|
11 |
+
litellm_model_mock = MagicMock()
|
12 |
+
code_agent_mock = MagicMock()
|
13 |
+
patch_context = contextlib.ExitStack()
|
|
|
14 |
mock_tool_collection = MagicMock()
|
15 |
|
|
|
16 |
mock_tool_collection.from_mcp.return_value.__enter__.return_value = (
|
17 |
mock_tool_collection
|
18 |
)
|
19 |
mock_tool_collection.from_mcp.return_value.__exit__.return_value = None
|
20 |
mock_tool_collection.tools = ["mock_tool"]
|
21 |
+
patch_context.enter_context(
|
22 |
+
patch("surf_spot_finder.agents.smolagents.StdioServerParameters", MagicMock())
|
23 |
+
)
|
24 |
+
patch_context.enter_context(
|
25 |
+
patch("surf_spot_finder.agents.smolagents.CodeAgent", code_agent_mock)
|
26 |
+
)
|
27 |
+
patch_context.enter_context(
|
28 |
+
patch("surf_spot_finder.agents.smolagents.LiteLLMModel", litellm_model_mock)
|
29 |
+
)
|
30 |
+
patch_context.enter_context(
|
31 |
+
patch("surf_spot_finder.agents.smolagents.ToolCollection", mock_tool_collection)
|
32 |
+
)
|
33 |
+
yield patch_context, litellm_model_mock, code_agent_mock, mock_tool_collection
|
34 |
+
patch_context.close()
|
35 |
+
|
36 |
|
37 |
+
def test_run_smolagent_with_api_key_var(common_patches):
|
38 |
+
patch_context, litellm_model_mock, code_agent_mock, *_ = common_patches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
+
with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
|
41 |
run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
|
42 |
|
43 |
+
litellm_model_mock.assert_called()
|
44 |
+
model_call_kwargs = litellm_model_mock.call_args[1]
|
45 |
assert model_call_kwargs["model_id"] == "openai/gpt-4"
|
46 |
assert model_call_kwargs["api_key"] == "test-key-12345"
|
47 |
assert model_call_kwargs["api_base"] is None
|
48 |
|
49 |
+
code_agent_mock.assert_called_once()
|
50 |
+
code_agent_mock.return_value.run.assert_called_once_with("Test prompt")
|
51 |
|
52 |
|
53 |
+
def test_run_smolagent_with_custom_api_base(common_patches):
|
54 |
+
patch_context, litellm_model_mock, *_ = common_patches
|
|
|
|
|
|
|
55 |
|
56 |
+
with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
|
57 |
run_smolagent(
|
58 |
"anthropic/claude-3-sonnet",
|
59 |
"Test prompt",
|
60 |
api_key_var="TEST_API_KEY",
|
61 |
api_base="https://custom-api.example.com",
|
62 |
)
|
63 |
+
last_call = litellm_model_mock.call_args_list[-1]
|
64 |
|
65 |
assert last_call[1]["model_id"] == "anthropic/claude-3-sonnet"
|
66 |
assert last_call[1]["api_key"] == "test-key-12345"
|
67 |
assert last_call[1]["api_base"] == "https://custom-api.example.com"
|
68 |
|
69 |
|
70 |
+
def test_run_smolagent_without_api_key(common_patches):
|
71 |
+
patch_context, litellm_model_mock, *_ = common_patches
|
|
|
|
|
72 |
|
73 |
+
with patch_context:
|
74 |
+
run_smolagent("ollama_chat/deepseek-r1", "Test prompt")
|
75 |
|
76 |
+
last_call = litellm_model_mock.call_args_list[-1]
|
77 |
assert last_call[1]["model_id"] == "ollama_chat/deepseek-r1"
|
78 |
assert last_call[1]["api_key"] is None
|
79 |
|
80 |
|
81 |
def test_run_smolagent_environment_error():
|
|
|
82 |
with patch.dict(os.environ, {}, clear=True):
|
83 |
with pytest.raises(KeyError, match="MISSING_KEY"):
|
84 |
run_smolagent("test-model", "Test prompt", api_key_var="MISSING_KEY")
|