David de la Iglesia Castro commited on
Commit
502c6d3
·
unverified ·
1 Parent(s): ee1c4f8

enh(cli): Add `from_config` option. (#17)

Browse files

* Add run_openai_multi_agent.

* config: Simplify DEFAULT_PROMPT.

* fix(test_unit_openai): Import after patch

* Update instructions. Add communication agent

* Iterate on instructions

* enh(tracing): Handle openai_multi_agent. Add `test_setup_tracing`.

* More instructions tuning

* Add test_run_openai_multiagent

* More tweaks

* enh(cli): Add `from_config` option.

- Add `examples` dir with input configs.
- Move prompts to `surf_spot_finder.prompts`.

* fix api.md

* fix(test_unit_smolagents): Don't rely on patching `sys`.

docs/api.md CHANGED
@@ -1,13 +1,19 @@
1
  # API Reference
2
 
 
 
3
  ::: surf_spot_finder.config.Config
4
 
5
  ::: surf_spot_finder.agents.RUNNERS
6
 
7
  ::: surf_spot_finder.agents.openai
8
 
9
- ::: surf_spot_finder.agents.openai.DEFAULT_MULTIAGENT_INSTRUCTIONS
10
-
11
  ::: surf_spot_finder.agents.smolagents
12
 
 
 
 
 
 
 
13
  ::: surf_spot_finder.tracing
 
1
  # API Reference
2
 
3
+ ::: surf_spot_finder.cli
4
+
5
  ::: surf_spot_finder.config.Config
6
 
7
  ::: surf_spot_finder.agents.RUNNERS
8
 
9
  ::: surf_spot_finder.agents.openai
10
 
 
 
11
  ::: surf_spot_finder.agents.smolagents
12
 
13
+ ::: surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT
14
+
15
+ ::: surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT
16
+
17
+ ::: surf_spot_finder.prompts.shared.INPUT_PROMPT
18
+
19
  ::: surf_spot_finder.tracing
examples/openai_multi_agent.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: o3-mini
5
+ agent_type: openai_multi_agent
6
+ # input_prompt_template:
examples/openai_single_agent.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: o3-mini
5
+ agent_type: openai
6
+ # input_prompt_template:
examples/smolagents_single_agent.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ location: Pontevedra
2
+ date: 2025-03-22 12:00
3
+ max_driving_hours: 2
4
+ model_id: openai/gpt-3.5-turbo
5
+ api_key_var: OPENAI_API_KEY
6
+ agent_type: smolagents
7
+ # input_prompt_template:
src/surf_spot_finder/agents/openai.py CHANGED
@@ -17,6 +17,12 @@ from smolagents import (
17
  )
18
 
19
 
 
 
 
 
 
 
20
  @function_tool
21
  def search_web(query: str) -> str:
22
  """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
@@ -69,7 +75,7 @@ def run_openai_agent(
69
  model_id: str,
70
  prompt: str,
71
  name: str = "surf-spot-finder",
72
- instructions: Optional[str] = None,
73
  api_key_var: Optional[str] = None,
74
  base_url: Optional[str] = None,
75
  ) -> RunResult:
@@ -87,7 +93,7 @@ def run_openai_agent(
87
  prompt (str): The prompt to be given to the agent.
88
  name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
89
  instructions (Optional[str], optional): Initial instructions to give the agent.
90
- Defaults to None.
91
  api_key_var (Optional[str], optional): The name of the environment variable
92
  containing the OpenAI API key. If provided, along with `base_url`, an
93
  external OpenAI client will be used. Defaults to None.
@@ -126,27 +132,12 @@ def run_openai_agent(
126
  return result
127
 
128
 
129
- DEFAULT_MULTIAGENT_INSTRUCTIONS = """
130
- You will be asked to perform a task.
131
-
132
- Always follow this steps:
133
-
134
- First, before solving the task, look at the available agent/tools and plan a sequence of actions using the available tools.
135
- Second, show the plan of actions and ask for user verification. If the user does not verify the plan, come up with a better plan.
136
- Third, execute the plan using the available tools, until you get a final answer.
137
-
138
- Once you get a final answer, show it and ask for user verification. If the user does not verify the answer, come up with a better answer.
139
-
140
- Finally, use the available handoff tool (`transfer_to_<agent_name>`) to communicate it to the user.
141
- """
142
-
143
-
144
  @logger.catch(reraise=True)
145
  def run_openai_multi_agent(
146
  model_id: str,
147
  prompt: str,
148
  name: str = "surf-spot-finder",
149
- instructions: Optional[str] = DEFAULT_MULTIAGENT_INSTRUCTIONS,
150
  ) -> RunResult:
151
  """Runs multiple OpenAI agents orchestrated by a main agent.
152
 
@@ -162,7 +153,7 @@ def run_openai_multi_agent(
162
  prompt (str): The prompt to be given to the agent.
163
  name (str, optional): The name of the main agent. Defaults to "surf-spot-finder".
164
  instructions (Optional[str], optional): Initial instructions to give the agent.
165
- Defaults to [DEFAULT_MULTIAGENT_INSTRUCTIONS][surf_spot_finder.agents.openai.DEFAULT_MULTIAGENT_INSTRUCTIONS].
166
 
167
  Returns:
168
  RunResult: A RunResult object containing the output of the agent run.
 
17
  )
18
 
19
 
20
+ from surf_spot_finder.prompts.openai import (
21
+ SINGLE_AGENT_SYSTEM_PROMPT,
22
+ MULTI_AGENT_SYSTEM_PROMPT,
23
+ )
24
+
25
+
26
  @function_tool
27
  def search_web(query: str) -> str:
28
  """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
 
75
  model_id: str,
76
  prompt: str,
77
  name: str = "surf-spot-finder",
78
+ instructions: Optional[str] = SINGLE_AGENT_SYSTEM_PROMPT,
79
  api_key_var: Optional[str] = None,
80
  base_url: Optional[str] = None,
81
  ) -> RunResult:
 
93
  prompt (str): The prompt to be given to the agent.
94
  name (str, optional): The name of the agent. Defaults to "surf-spot-finder".
95
  instructions (Optional[str], optional): Initial instructions to give the agent.
96
+ Defaults to [SINGLE_AGENT_SYSTEM_PROMPT][surf_spot_finder.prompts.openai.SINGLE_AGENT_SYSTEM_PROMPT].
97
  api_key_var (Optional[str], optional): The name of the environment variable
98
  containing the OpenAI API key. If provided, along with `base_url`, an
99
  external OpenAI client will be used. Defaults to None.
 
132
  return result
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  @logger.catch(reraise=True)
136
  def run_openai_multi_agent(
137
  model_id: str,
138
  prompt: str,
139
  name: str = "surf-spot-finder",
140
+ instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
141
  ) -> RunResult:
142
  """Runs multiple OpenAI agents orchestrated by a main agent.
143
 
 
153
  prompt (str): The prompt to be given to the agent.
154
  name (str, optional): The name of the main agent. Defaults to "surf-spot-finder".
155
  instructions (Optional[str], optional): Initial instructions to give the agent.
156
+ Defaults to [MULTI_AGENT_SYSTEM_PROMPT][surf_spot_finder.prompts.openai.MULTI_AGENT_SYSTEM_PROMPT].
157
 
158
  Returns:
159
  RunResult: A RunResult object containing the output of the agent run.
src/surf_spot_finder/agents/smolagents.py CHANGED
@@ -1,10 +1,16 @@
1
  import os
2
- from typing import Optional, TYPE_CHECKING
3
 
4
  from loguru import logger
5
 
6
- if TYPE_CHECKING:
7
- from smolagents import CodeAgent
 
 
 
 
 
 
8
 
9
 
10
  @logger.catch(reraise=True)
@@ -13,7 +19,7 @@ def run_smolagent(
13
  prompt: str,
14
  api_key_var: Optional[str] = None,
15
  api_base: Optional[str] = None,
16
- ) -> "CodeAgent":
17
  """
18
  Create and configure a Smolagents CodeAgent with the specified model.
19
 
@@ -29,17 +35,10 @@ def run_smolagent(
29
  CodeAgent: Configured agent ready to process requests
30
 
31
  Example:
 
32
  >>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
33
  >>> agent.run("Find surf spots near San Diego")
34
  """
35
- from smolagents import ( # pylint: disable=import-outside-toplevel
36
- CodeAgent,
37
- DuckDuckGoSearchTool,
38
- LiteLLMModel,
39
- ToolCollection,
40
- )
41
- from mcp import StdioServerParameters
42
- from surf_spot_finder.agents.prompts.smolagents import SYSTEM_PROMPT
43
 
44
  model = LiteLLMModel(
45
  model_id=model_id,
 
1
  import os
2
+ from typing import Optional
3
 
4
  from loguru import logger
5
 
6
+ from smolagents import (
7
+ CodeAgent,
8
+ DuckDuckGoSearchTool,
9
+ LiteLLMModel,
10
+ ToolCollection,
11
+ )
12
+ from mcp import StdioServerParameters
13
+ from surf_spot_finder.prompts.smolagents import SYSTEM_PROMPT
14
 
15
 
16
  @logger.catch(reraise=True)
 
19
  prompt: str,
20
  api_key_var: Optional[str] = None,
21
  api_base: Optional[str] = None,
22
+ ) -> CodeAgent:
23
  """
24
  Create and configure a Smolagents CodeAgent with the specified model.
25
 
 
35
  CodeAgent: Configured agent ready to process requests
36
 
37
  Example:
38
+
39
  >>> agent = run_smolagent("anthropic/claude-3-haiku", "my prompt here", "ANTHROPIC_API_KEY", None, None)
40
  >>> agent.run("Find surf spots near San Diego")
41
  """
 
 
 
 
 
 
 
 
42
 
43
  model = LiteLLMModel(
44
  model_id=model_id,
src/surf_spot_finder/cli.py CHANGED
@@ -1,40 +1,72 @@
 
1
  from typing import Optional
2
 
 
3
  from fire import Fire
4
  from loguru import logger
5
 
6
  from surf_spot_finder.config import (
7
  Config,
8
- DEFAULT_PROMPT,
9
  )
10
  from surf_spot_finder.agents import RUNNERS
 
11
  from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
12
 
13
 
14
  @logger.catch(reraise=True)
15
  def find_surf_spot(
16
- location: str,
17
- date: str,
18
- max_driving_hours: int,
19
- model_id: str,
20
  agent_type: str = "smolagents",
21
  api_key_var: Optional[str] = None,
22
- prompt: str = DEFAULT_PROMPT,
23
  json_tracer: bool = True,
24
  api_base: Optional[str] = None,
 
25
  ):
26
- logger.info("Loading config")
27
- config = Config(
28
- location=location,
29
- date=date,
30
- max_driving_hours=max_driving_hours,
31
- model_id=model_id,
32
- agent_type=agent_type,
33
- api_key_var=api_key_var,
34
- prompt=prompt,
35
- json_tracer=json_tracer,
36
- api_base=api_base,
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  logger.info("Setting up tracing")
40
  tracer_provider, _ = get_tracer_provider(
@@ -45,7 +77,7 @@ def find_surf_spot(
45
  logger.info(f"Running {config.agent_type} agent")
46
  RUNNERS[config.agent_type](
47
  model_id=config.model_id,
48
- prompt=config.prompt.format(
49
  LOCATION=config.location,
50
  MAX_DRIVING_HOURS=config.max_driving_hours,
51
  DATE=config.date,
 
1
+ from pathlib import Path
2
  from typing import Optional
3
 
4
+ import yaml
5
  from fire import Fire
6
  from loguru import logger
7
 
8
  from surf_spot_finder.config import (
9
  Config,
 
10
  )
11
  from surf_spot_finder.agents import RUNNERS
12
+ from surf_spot_finder.prompts.shared import INPUT_PROMPT
13
  from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
14
 
15
 
16
  @logger.catch(reraise=True)
17
  def find_surf_spot(
18
+ location: Optional[str] = None,
19
+ date: Optional[str] = None,
20
+ max_driving_hours: Optional[int] = None,
21
+ model_id: Optional[str] = None,
22
  agent_type: str = "smolagents",
23
  api_key_var: Optional[str] = None,
24
+ input_prompt_template: str = INPUT_PROMPT,
25
  json_tracer: bool = True,
26
  api_base: Optional[str] = None,
27
+ from_config: Optional[str] = None,
28
  ):
29
+ """Find the best surf spot based on the given criteria.
30
+
31
+ Args:
32
+ location (str): The location to search around.
33
+ Required if `from_config` is not provided.
34
+ date (str): The date to search for.
35
+ Required if `from_config` is not provided.
36
+ max_driving_hours (int): The maximum driving hours from the location.
37
+ Required if `from_config` is not provided.
38
+ model_id (str): The ID of the model to use.
39
+ Required if `from_config` is not provided.
40
+
41
+ If using `agent_type=smolagents`, use LiteLLM syntax (e.g., 'openai/o1', 'anthropic/claude-3-sonnet').
42
+ If using `agent_type={openai,openai_multi_agent}`, use OpenAI syntax (e.g., 'o1').
43
+ agent_type (str, optional): The type of agent to use.
44
+ Must be one of the supported types in [RUNNERS][surf_spot_finder.agents.RUNNERS].
45
+ api_key_var (Optional[str], optional): The name of the environment variable containing the API key.
46
+ input_prompt_template (str, optional): The template for the imput_prompt.
47
+
48
+ Must contain the following placeholders: `{LOCATION}`, `{MAX_DRIVING_HOURS}`, and `{DATE}`.
49
+ json_tracer (bool, optional): Whether to use the custom JSON file exporter.
50
+ api_base (Optional[str], optional): The base URL for the API.
51
+ from_config (Optional[str], optional): Path to a YAML config file.
52
+
53
+ If provided, all other arguments will be ignored.
54
+ """
55
+ if from_config:
56
+ logger.info(f"Loading {from_config}")
57
+ config = Config.model_validate(yaml.safe_load(Path(from_config).read_text()))
58
+ else:
59
+ config = Config(
60
+ location=location,
61
+ date=date,
62
+ max_driving_hours=max_driving_hours,
63
+ model_id=model_id,
64
+ agent_type=agent_type,
65
+ api_key_var=api_key_var,
66
+ prompt=input_prompt_template,
67
+ json_tracer=json_tracer,
68
+ api_base=api_base,
69
+ )
70
 
71
  logger.info("Setting up tracing")
72
  tracer_provider, _ = get_tracer_provider(
 
77
  logger.info(f"Running {config.agent_type} agent")
78
  RUNNERS[config.agent_type](
79
  model_id=config.model_id,
80
+ prompt=config.input_prompt_template.format(
81
  LOCATION=config.location,
82
  MAX_DRIVING_HOURS=config.max_driving_hours,
83
  DATE=config.date,
src/surf_spot_finder/config.py CHANGED
@@ -1,14 +1,7 @@
1
  from typing import Annotated, Optional
2
  from pydantic import AfterValidator, BaseModel, FutureDatetime, PositiveInt
3
- from datetime import datetime
4
 
5
- CURRENT_DATE = datetime.now().strftime("%Y-%m-%d")
6
-
7
- DEFAULT_PROMPT = (
8
- "What will be the best surf spot around {LOCATION}"
9
- ", in a {MAX_DRIVING_HOURS} hour driving radius"
10
- ", at {DATE}?"
11
- )
12
 
13
 
14
  def validate_prompt(value) -> str:
@@ -26,7 +19,9 @@ def validate_agent_type(value) -> str:
26
 
27
 
28
  class Config(BaseModel):
29
- prompt: Annotated[str, AfterValidator(validate_prompt)]
 
 
30
  location: str
31
  max_driving_hours: PositiveInt
32
  date: FutureDatetime
 
1
  from typing import Annotated, Optional
2
  from pydantic import AfterValidator, BaseModel, FutureDatetime, PositiveInt
 
3
 
4
+ from surf_spot_finder.prompts.shared import INPUT_PROMPT
 
 
 
 
 
 
5
 
6
 
7
  def validate_prompt(value) -> str:
 
19
 
20
 
21
  class Config(BaseModel):
22
+ input_prompt_template: Annotated[str, AfterValidator(validate_prompt)] = (
23
+ INPUT_PROMPT
24
+ )
25
  location: str
26
  max_driving_hours: PositiveInt
27
  date: FutureDatetime
src/surf_spot_finder/prompts/__init__.py ADDED
File without changes
src/surf_spot_finder/prompts/openai.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SINGLE_AGENT_SYSTEM_PROMPT = """
2
+ You will be asked to perform a task.
3
+
4
+ Before solving the task, plan a sequence of actions using the available tools.
5
+ Then, execute the sequence of actions using the tools.
6
+ """.strip()
7
+
8
+ MULTI_AGENT_SYSTEM_PROMPT = """
9
+ You will be asked to perform a task.
10
+
11
+ Always follow these steps:
12
+
13
+ First, before solving the task, plan a sequence of actions using the available tools.
14
+ Second, show the plan of actions and ask for user verification. If the user does not verify the plan, come up with a better plan.
15
+ Third, execute the plan using the available tools, until you get a final answer.
16
+
17
+ Once you get a final answer, show it and ask for user verification. If the user does not verify the answer, come up with a better answer.
18
+
19
+ Finally, use the available handoff tool (`transfer_to_<agent_name>`) to communicate it to the user.
20
+ """.strip()
src/surf_spot_finder/prompts/shared.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ INPUT_PROMPT = """
2
+ According to the forecast, what will be the best spot to surf around {LOCATION},
3
+ in a {MAX_DRIVING_HOURS} hour driving radius,
4
+ at {DATE}?
5
+ """.strip()
src/surf_spot_finder/{agents/prompts → prompts}/smolagents.py RENAMED
File without changes
tests/unit/agents/test_unit_openai.py CHANGED
@@ -9,7 +9,10 @@ from surf_spot_finder.agents.openai import (
9
  search_web,
10
  user_verification,
11
  visit_webpage,
12
- DEFAULT_MULTIAGENT_INSTRUCTIONS,
 
 
 
13
  )
14
 
15
 
@@ -23,7 +26,7 @@ def test_run_openai_agent_default():
23
  run_openai_agent("gpt-4o", "Test prompt")
24
  mock_agent.assert_called_once_with(
25
  model="gpt-4o",
26
- instructions=None,
27
  name="surf-spot-finder",
28
  tools=[search_web, visit_webpage],
29
  )
@@ -91,7 +94,7 @@ def test_run_openai_multiagent():
91
 
92
  mock_agent.assert_any_call(
93
  model="gpt-4o",
94
- instructions=DEFAULT_MULTIAGENT_INSTRUCTIONS,
95
  name="surf-spot-finder",
96
  # TODO: add more elaborated checks
97
  handoffs=ANY,
 
9
  search_web,
10
  user_verification,
11
  visit_webpage,
12
+ )
13
+ from surf_spot_finder.prompts.openai import (
14
+ SINGLE_AGENT_SYSTEM_PROMPT,
15
+ MULTI_AGENT_SYSTEM_PROMPT,
16
  )
17
 
18
 
 
26
  run_openai_agent("gpt-4o", "Test prompt")
27
  mock_agent.assert_called_once_with(
28
  model="gpt-4o",
29
+ instructions=SINGLE_AGENT_SYSTEM_PROMPT,
30
  name="surf-spot-finder",
31
  tools=[search_web, visit_webpage],
32
  )
 
94
 
95
  mock_agent.assert_any_call(
96
  model="gpt-4o",
97
+ instructions=MULTI_AGENT_SYSTEM_PROMPT,
98
  name="surf-spot-finder",
99
  # TODO: add more elaborated checks
100
  handoffs=ANY,
tests/unit/agents/test_unit_smolagents.py CHANGED
@@ -1,103 +1,84 @@
1
  import os
2
  import pytest
3
  from unittest.mock import patch, MagicMock
 
4
 
5
  from surf_spot_finder.agents.smolagents import run_smolagent
6
 
7
 
8
  @pytest.fixture
9
- def mock_smolagents_imports():
10
- """Mock the smolagents imports to avoid actual instantiation."""
11
- mock_code_agent = MagicMock()
12
- mock_ddg_tool = MagicMock()
13
- mock_litellm_model = MagicMock()
14
  mock_tool_collection = MagicMock()
15
 
16
- # Configure the mock tool collection to work as a context manager
17
  mock_tool_collection.from_mcp.return_value.__enter__.return_value = (
18
  mock_tool_collection
19
  )
20
  mock_tool_collection.from_mcp.return_value.__exit__.return_value = None
21
  mock_tool_collection.tools = ["mock_tool"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- with patch.dict(
24
- "sys.modules",
25
- {
26
- "smolagents": MagicMock(
27
- CodeAgent=mock_code_agent,
28
- DuckDuckGoSearchTool=mock_ddg_tool,
29
- LiteLLMModel=mock_litellm_model,
30
- ToolCollection=mock_tool_collection,
31
- ),
32
- "mcp": MagicMock(
33
- StdioServerParameters=MagicMock(),
34
- ),
35
- },
36
- ):
37
- yield {
38
- "CodeAgent": mock_code_agent,
39
- "DuckDuckGoSearchTool": mock_ddg_tool,
40
- "LiteLLMModel": mock_litellm_model,
41
- "ToolCollection": mock_tool_collection,
42
- }
43
-
44
-
45
- @pytest.mark.usefixtures("mock_smolagents_imports")
46
- def test_run_smolagent_with_api_key_var():
47
- """Test smolagent creation with an API key from environment variable."""
48
- # The patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"})
49
- # is a testing construct that temporarily modifies the environment variables
50
- # for the duration of the test.
51
- # some tests use TEST_API_KEY while others don't
52
- with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
53
- from smolagents import CodeAgent, LiteLLMModel
54
 
 
55
  run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
56
 
57
- LiteLLMModel.assert_called()
58
- model_call_kwargs = LiteLLMModel.call_args[1]
59
  assert model_call_kwargs["model_id"] == "openai/gpt-4"
60
  assert model_call_kwargs["api_key"] == "test-key-12345"
61
  assert model_call_kwargs["api_base"] is None
62
 
63
- CodeAgent.assert_called_once()
64
- CodeAgent.return_value.run.assert_called_once_with("Test prompt")
65
 
66
 
67
- @pytest.mark.usefixtures("mock_smolagents_imports")
68
- def test_run_smolagent_with_custom_api_base():
69
- """Test smolagent creation with a custom API base."""
70
- with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
71
- from smolagents import LiteLLMModel
72
 
73
- # Act
74
  run_smolagent(
75
  "anthropic/claude-3-sonnet",
76
  "Test prompt",
77
  api_key_var="TEST_API_KEY",
78
  api_base="https://custom-api.example.com",
79
  )
80
- last_call = LiteLLMModel.call_args_list[-1]
81
 
82
  assert last_call[1]["model_id"] == "anthropic/claude-3-sonnet"
83
  assert last_call[1]["api_key"] == "test-key-12345"
84
  assert last_call[1]["api_base"] == "https://custom-api.example.com"
85
 
86
 
87
- @pytest.mark.usefixtures("mock_smolagents_imports")
88
- def test_run_smolagent_without_api_key():
89
- """You should be able to run the smolagent without an API key."""
90
- from smolagents import LiteLLMModel
91
 
92
- run_smolagent("ollama_chat/deepseek-r1", "Test prompt")
 
93
 
94
- last_call = LiteLLMModel.call_args_list[-1]
95
  assert last_call[1]["model_id"] == "ollama_chat/deepseek-r1"
96
  assert last_call[1]["api_key"] is None
97
 
98
 
99
  def test_run_smolagent_environment_error():
100
- """Test that passing a bad api_key_var throws an error"""
101
  with patch.dict(os.environ, {}, clear=True):
102
  with pytest.raises(KeyError, match="MISSING_KEY"):
103
  run_smolagent("test-model", "Test prompt", api_key_var="MISSING_KEY")
 
1
  import os
2
  import pytest
3
  from unittest.mock import patch, MagicMock
4
+ import contextlib
5
 
6
  from surf_spot_finder.agents.smolagents import run_smolagent
7
 
8
 
9
  @pytest.fixture
10
+ def common_patches():
11
+ litellm_model_mock = MagicMock()
12
+ code_agent_mock = MagicMock()
13
+ patch_context = contextlib.ExitStack()
 
14
  mock_tool_collection = MagicMock()
15
 
 
16
  mock_tool_collection.from_mcp.return_value.__enter__.return_value = (
17
  mock_tool_collection
18
  )
19
  mock_tool_collection.from_mcp.return_value.__exit__.return_value = None
20
  mock_tool_collection.tools = ["mock_tool"]
21
+ patch_context.enter_context(
22
+ patch("surf_spot_finder.agents.smolagents.StdioServerParameters", MagicMock())
23
+ )
24
+ patch_context.enter_context(
25
+ patch("surf_spot_finder.agents.smolagents.CodeAgent", code_agent_mock)
26
+ )
27
+ patch_context.enter_context(
28
+ patch("surf_spot_finder.agents.smolagents.LiteLLMModel", litellm_model_mock)
29
+ )
30
+ patch_context.enter_context(
31
+ patch("surf_spot_finder.agents.smolagents.ToolCollection", mock_tool_collection)
32
+ )
33
+ yield patch_context, litellm_model_mock, code_agent_mock, mock_tool_collection
34
+ patch_context.close()
35
+
36
 
37
+ def test_run_smolagent_with_api_key_var(common_patches):
38
+ patch_context, litellm_model_mock, code_agent_mock, *_ = common_patches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
41
  run_smolagent("openai/gpt-4", "Test prompt", api_key_var="TEST_API_KEY")
42
 
43
+ litellm_model_mock.assert_called()
44
+ model_call_kwargs = litellm_model_mock.call_args[1]
45
  assert model_call_kwargs["model_id"] == "openai/gpt-4"
46
  assert model_call_kwargs["api_key"] == "test-key-12345"
47
  assert model_call_kwargs["api_base"] is None
48
 
49
+ code_agent_mock.assert_called_once()
50
+ code_agent_mock.return_value.run.assert_called_once_with("Test prompt")
51
 
52
 
53
+ def test_run_smolagent_with_custom_api_base(common_patches):
54
+ patch_context, litellm_model_mock, *_ = common_patches
 
 
 
55
 
56
+ with patch_context, patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
57
  run_smolagent(
58
  "anthropic/claude-3-sonnet",
59
  "Test prompt",
60
  api_key_var="TEST_API_KEY",
61
  api_base="https://custom-api.example.com",
62
  )
63
+ last_call = litellm_model_mock.call_args_list[-1]
64
 
65
  assert last_call[1]["model_id"] == "anthropic/claude-3-sonnet"
66
  assert last_call[1]["api_key"] == "test-key-12345"
67
  assert last_call[1]["api_base"] == "https://custom-api.example.com"
68
 
69
 
70
+ def test_run_smolagent_without_api_key(common_patches):
71
+ patch_context, litellm_model_mock, *_ = common_patches
 
 
72
 
73
+ with patch_context:
74
+ run_smolagent("ollama_chat/deepseek-r1", "Test prompt")
75
 
76
+ last_call = litellm_model_mock.call_args_list[-1]
77
  assert last_call[1]["model_id"] == "ollama_chat/deepseek-r1"
78
  assert last_call[1]["api_key"] is None
79
 
80
 
81
  def test_run_smolagent_environment_error():
 
82
  with patch.dict(os.environ, {}, clear=True):
83
  with pytest.raises(KeyError, match="MISSING_KEY"):
84
  run_smolagent("test-model", "Test prompt", api_key_var="MISSING_KEY")