David de la Iglesia Castro committed on
Commit
ee1c4f8
·
unverified ·
1 Parent(s): 79792df

15 implement simple multi agent workflow (#16)

Browse files

* Add run_openai_multi_agent.

* config: Simplify DEFAULT_PROMPT.

* fix(test_unit_openai): Import after patch

* Update instructions. Add communication agent

* Iterate on instructions

* enh(tracing): Handle openai_multi_agent. Add `test_setup_tracing`.

* More instructions tuning

* Add test_run_openai_multiagent

* More tweaks

docs/api.md CHANGED
@@ -6,6 +6,8 @@
6
 
7
  ::: surf_spot_finder.agents.openai
8
 
 
 
9
  ::: surf_spot_finder.agents.smolagents
10
 
11
  ::: surf_spot_finder.tracing
 
6
 
7
  ::: surf_spot_finder.agents.openai
8
 
9
+ ::: surf_spot_finder.agents.openai.DEFAULT_MULTIAGENT_INSTRUCTIONS
10
+
11
  ::: surf_spot_finder.agents.smolagents
12
 
13
  ::: surf_spot_finder.tracing
src/surf_spot_finder/agents/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
- from .openai import run_openai_agent
2
  from .smolagents import run_smolagent
3
 
4
  RUNNERS = {
5
  "openai": run_openai_agent,
6
  "smolagents": run_smolagent,
 
7
  }
8
 
9
 
 
1
+ from .openai import run_openai_agent, run_openai_multi_agent
2
  from .smolagents import run_smolagent
3
 
4
  RUNNERS = {
5
  "openai": run_openai_agent,
6
  "smolagents": run_smolagent,
7
+ "openai_multi_agent": run_openai_multi_agent,
8
  }
9
 
10
 
src/surf_spot_finder/agents/openai.py CHANGED
@@ -1,10 +1,67 @@
1
  import os
2
- from typing import Optional, TYPE_CHECKING
3
 
 
 
 
 
 
 
 
 
4
  from loguru import logger
 
 
 
 
 
5
 
6
- if TYPE_CHECKING:
7
- from agents import RunResult
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  @logger.catch(reraise=True)
@@ -15,7 +72,7 @@ def run_openai_agent(
15
  instructions: Optional[str] = None,
16
  api_key_var: Optional[str] = None,
17
  base_url: Optional[str] = None,
18
- ) -> "RunResult":
19
  """Runs an OpenAI agent with the given prompt and configuration.
20
 
21
  It leverages the 'agents' library to create and manage the agent
@@ -42,34 +99,6 @@ def run_openai_agent(
42
  RunResult: A RunResult object containing the output of the agent run.
43
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
44
  """
45
- from agents import (
46
- Agent,
47
- AsyncOpenAI,
48
- OpenAIChatCompletionsModel,
49
- Runner,
50
- function_tool,
51
- )
52
- from smolagents import DuckDuckGoSearchTool, VisitWebpageTool
53
-
54
- @function_tool
55
- def search_web(query: str) -> str:
56
- """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
57
-
58
- Args:
59
- query: The search query to perform.
60
- """
61
- search_tool = DuckDuckGoSearchTool()
62
- return search_tool.forward(query)
63
-
64
- @function_tool
65
- def visit_webpage(url: str) -> str:
66
- """Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages.
67
-
68
- Args:
69
- url: The url of the webpage to visit.
70
- """
71
- visit_tool = VisitWebpageTool()
72
- return visit_tool.forward(url)
73
 
74
  if api_key_var and base_url:
75
  external_client = AsyncOpenAI(
@@ -95,3 +124,88 @@ def run_openai_agent(
95
  result = Runner.run_sync(agent, prompt)
96
  logger.info(result.final_output)
97
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Optional
3
 
4
+ from agents import (
5
+ Agent,
6
+ AsyncOpenAI,
7
+ OpenAIChatCompletionsModel,
8
+ Runner,
9
+ RunResult,
10
+ function_tool,
11
+ )
12
  from loguru import logger
13
+ from smolagents import (
14
+ DuckDuckGoSearchTool,
15
+ VisitWebpageTool,
16
+ FinalAnswerTool,
17
+ )
18
 
19
+
20
+ @function_tool
21
+ def search_web(query: str) -> str:
22
+ """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results.
23
+
24
+ Args:
25
+ query: The search query to perform.
26
+ """
27
+ logger.debug(f"Calling search_web: {query}")
28
+ search_tool = DuckDuckGoSearchTool()
29
+ return search_tool.forward(query)
30
+
31
+
32
+ @function_tool
33
+ def visit_webpage(url: str) -> str:
34
+ """Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages.
35
+
36
+ Args:
37
+ url: The url of the webpage to visit.
38
+ """
39
+ logger.debug(f"Calling visit_webpage: {url}")
40
+ visit_tool = VisitWebpageTool()
41
+ return visit_tool.forward(url)
42
+
43
+
44
+ @function_tool
45
+ def final_answer(answer: str) -> str:
46
+ """Provides a final answer to the given problem.
47
+
48
+ Args:
49
+ answer: The answer to the problem.
50
+ """
51
+ logger.debug("Calling final_answer")
52
+ final_answer_tool = FinalAnswerTool()
53
+ return final_answer_tool.forward(answer)
54
+
55
+
56
+ @function_tool
57
+ def user_verification(query: str) -> str:
58
+ """Asks user to verify the given `query`.
59
+
60
+ Args:
61
+ query: The question that requires verification.
62
+ """
63
+ logger.debug("Calling user_verification")
64
+ return input(f"{query} => Type your answer here:")
65
 
66
 
67
  @logger.catch(reraise=True)
 
72
  instructions: Optional[str] = None,
73
  api_key_var: Optional[str] = None,
74
  base_url: Optional[str] = None,
75
+ ) -> RunResult:
76
  """Runs an OpenAI agent with the given prompt and configuration.
77
 
78
  It leverages the 'agents' library to create and manage the agent
 
99
  RunResult: A RunResult object containing the output of the agent run.
100
  See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
101
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if api_key_var and base_url:
104
  external_client = AsyncOpenAI(
 
124
  result = Runner.run_sync(agent, prompt)
125
  logger.info(result.final_output)
126
  return result
127
+
128
+
129
+ DEFAULT_MULTIAGENT_INSTRUCTIONS = """
130
+ You will be asked to perform a task.
131
+
132
+ Always follow this steps:
133
+
134
+ First, before solving the task, look at the available agent/tools and plan a sequence of actions using the available tools.
135
+ Second, show the plan of actions and ask for user verification. If the user does not verify the plan, come up with a better plan.
136
+ Third, execute the plan using the available tools, until you get a final answer.
137
+
138
+ Once you get a final answer, show it and ask for user verification. If the user does not verify the answer, come up with a better answer.
139
+
140
+ Finally, use the available handoff tool (`transfer_to_<agent_name>`) to communicate it to the user.
141
+ """
142
+
143
+
144
+ @logger.catch(reraise=True)
145
+ def run_openai_multi_agent(
146
+ model_id: str,
147
+ prompt: str,
148
+ name: str = "surf-spot-finder",
149
+ instructions: Optional[str] = DEFAULT_MULTIAGENT_INSTRUCTIONS,
150
+ ) -> RunResult:
151
+ """Runs multiple OpenAI agents orchestrated by a main agent.
152
+
153
+ It leverages the 'agents' library to create and manage the agent
154
+ execution.
155
+
156
+ See https://openai.github.io/openai-agents-python/ref/agent/ for more details.
157
+
158
+
159
+ Args:
160
+ model_id (str): The ID of the OpenAI model to use (e.g., "gpt4o").
161
+ See https://platform.openai.com/docs/api-reference/models.
162
+ prompt (str): The prompt to be given to the agent.
163
+ name (str, optional): The name of the main agent. Defaults to "surf-spot-finder".
164
+ instructions (Optional[str], optional): Initial instructions to give the agent.
165
+ Defaults to [DEFAULT_MULTIAGENT_INSTRUCTIONS][surf_spot_finder.agents.openai.DEFAULT_MULTIAGENT_INSTRUCTIONS].
166
+
167
+ Returns:
168
+ RunResult: A RunResult object containing the output of the agent run.
169
+ See https://openai.github.io/openai-agents-python/ref/result/#agents.result.RunResult.
170
+ """
171
+ user_verification_agent = Agent(
172
+ model=model_id,
173
+ instructions="Display the current output to the user, then ask for verification.",
174
+ name="user-verification-agent",
175
+ tools=[user_verification],
176
+ )
177
+
178
+ search_web_agent = Agent(
179
+ model=model_id,
180
+ instructions="Find relevant information about the provided task by combining web searches with visiting webpages.",
181
+ name="search-web-agent",
182
+ tools=[search_web, visit_webpage],
183
+ )
184
+
185
+ communication_agent = Agent(
186
+ model=model_id,
187
+ instructions=None,
188
+ name="communication-agent",
189
+ tools=[final_answer],
190
+ )
191
+
192
+ main_agent = Agent(
193
+ model=model_id,
194
+ instructions=instructions,
195
+ name=name,
196
+ handoffs=[communication_agent],
197
+ tools=[
198
+ search_web_agent.as_tool(
199
+ tool_name="search_web_with_agent",
200
+ tool_description=search_web_agent.instructions,
201
+ ),
202
+ user_verification_agent.as_tool(
203
+ tool_name="ask_user_verification_with_agent",
204
+ tool_description=user_verification_agent.instructions,
205
+ ),
206
+ ],
207
+ )
208
+
209
+ result = Runner.run_sync(main_agent, prompt)
210
+ logger.info(result.final_output)
211
+ return result
src/surf_spot_finder/config.py CHANGED
@@ -7,12 +7,7 @@ CURRENT_DATE = datetime.now().strftime("%Y-%m-%d")
7
  DEFAULT_PROMPT = (
8
  "What will be the best surf spot around {LOCATION}"
9
  ", in a {MAX_DRIVING_HOURS} hour driving radius"
10
- ", at {DATE}? it is currently "
11
- + CURRENT_DATE
12
- + ". find me the best surf spot and also report back"
13
- " on the expected water temperature and wave height."
14
- " Please remember that doing a google/duckduckgo search may be useful for finding which sites are relevant,"
15
- " but the final answer should be based on information retrieved from https://www.surf-forecast.com."
16
  )
17
 
18
 
 
7
  DEFAULT_PROMPT = (
8
  "What will be the best surf spot around {LOCATION}"
9
  ", in a {MAX_DRIVING_HOURS} hour driving radius"
10
+ ", at {DATE}?"
 
 
 
 
 
11
  )
12
 
13
 
src/surf_spot_finder/tracing.py CHANGED
@@ -95,7 +95,7 @@ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
95
 
96
  validate_agent_type(agent_type)
97
 
98
- if agent_type == "openai":
99
  from openinference.instrumentation.openai import OpenAIInstrumentor
100
 
101
  OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
 
95
 
96
  validate_agent_type(agent_type)
97
 
98
+ if "openai" in agent_type:
99
  from openinference.instrumentation.openai import OpenAIInstrumentor
100
 
101
  OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)
tests/unit/agents/test_unit_openai.py CHANGED
@@ -2,56 +2,98 @@ import os
2
  import pytest
3
  from unittest.mock import patch, MagicMock, ANY
4
 
5
- from surf_spot_finder.agents.openai import run_openai_agent
6
-
7
-
8
- @pytest.fixture
9
- def mock_agents_module():
10
- agents_mocks = {
11
- name: MagicMock()
12
- for name in (
13
- "Agent",
14
- "AsyncOpenAI",
15
- "OpenAIChatCompletionsModel",
16
- "Runner",
17
- "WebSearchTool",
18
- )
19
- }
20
- with patch.dict(
21
- "sys.modules",
22
- {
23
- "agents": MagicMock(**agents_mocks),
24
- },
25
- ):
26
- yield agents_mocks
27
 
 
 
28
 
29
- def test_run_openai_agent_default(mock_agents_module):
30
- run_openai_agent("gpt-4o", "Test prompt")
31
- mock_agents_module["Agent"].assert_called_once_with(
32
- model="gpt-4o",
33
- instructions=None,
34
- name="surf-spot-finder",
35
- tools=ANY,
36
- )
 
 
 
37
 
38
 
39
- def test_run_openai_agent_base_url_and_api_key_var(mock_agents_module):
40
- with patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}):
 
 
 
 
 
 
 
 
 
 
 
41
  run_openai_agent(
42
  "gpt-4o", "Test prompt", base_url="FOO", api_key_var="TEST_API_KEY"
43
  )
44
- mock_agents_module["AsyncOpenAI"].assert_called_once_with(
45
  api_key="test-key-12345",
46
  base_url="FOO",
47
  )
48
- mock_agents_module["OpenAIChatCompletionsModel"].assert_called_once()
49
 
50
 
51
- def test_run_smolagent_environment_error():
52
- """Test that passing a bad api_key_var throws an error"""
53
  with patch.dict(os.environ, {}, clear=True):
54
  with pytest.raises(KeyError, match="MISSING_KEY"):
55
  run_openai_agent(
56
  "test-model", "Test prompt", base_url="FOO", api_key_var="MISSING_KEY"
57
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pytest
3
  from unittest.mock import patch, MagicMock, ANY
4
 
5
+ from surf_spot_finder.agents.openai import (
6
+ final_answer,
7
+ run_openai_agent,
8
+ run_openai_multi_agent,
9
+ search_web,
10
+ user_verification,
11
+ visit_webpage,
12
+ DEFAULT_MULTIAGENT_INSTRUCTIONS,
13
+ )
14
+
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ def test_run_openai_agent_default():
17
+ mock_agent = MagicMock()
18
 
19
+ with (
20
+ patch("surf_spot_finder.agents.openai.Agent", mock_agent),
21
+ patch("surf_spot_finder.agents.openai.Runner", MagicMock()),
22
+ ):
23
+ run_openai_agent("gpt-4o", "Test prompt")
24
+ mock_agent.assert_called_once_with(
25
+ model="gpt-4o",
26
+ instructions=None,
27
+ name="surf-spot-finder",
28
+ tools=[search_web, visit_webpage],
29
+ )
30
 
31
 
32
+ def test_run_openai_agent_base_url_and_api_key_var():
33
+ async_openai_mock = MagicMock()
34
+ openai_chat_completions_model = MagicMock()
35
+ with (
36
+ patch("surf_spot_finder.agents.openai.Agent", MagicMock()),
37
+ patch("surf_spot_finder.agents.openai.Runner", MagicMock()),
38
+ patch("surf_spot_finder.agents.openai.AsyncOpenAI", async_openai_mock),
39
+ patch(
40
+ "surf_spot_finder.agents.openai.OpenAIChatCompletionsModel",
41
+ openai_chat_completions_model,
42
+ ),
43
+ patch.dict(os.environ, {"TEST_API_KEY": "test-key-12345"}),
44
+ ):
45
  run_openai_agent(
46
  "gpt-4o", "Test prompt", base_url="FOO", api_key_var="TEST_API_KEY"
47
  )
48
+ async_openai_mock.assert_called_once_with(
49
  api_key="test-key-12345",
50
  base_url="FOO",
51
  )
52
+ openai_chat_completions_model.assert_called_once()
53
 
54
 
55
+ def test_run_openai_environment_error():
 
56
  with patch.dict(os.environ, {}, clear=True):
57
  with pytest.raises(KeyError, match="MISSING_KEY"):
58
  run_openai_agent(
59
  "test-model", "Test prompt", base_url="FOO", api_key_var="MISSING_KEY"
60
  )
61
+
62
+
63
+ def test_run_openai_multiagent():
64
+ mock_agent = MagicMock()
65
+
66
+ with (
67
+ patch("surf_spot_finder.agents.openai.Agent", mock_agent),
68
+ patch("surf_spot_finder.agents.openai.Runner", MagicMock()),
69
+ ):
70
+ run_openai_multi_agent("gpt-4o", "Test prompt")
71
+ mock_agent.assert_any_call(
72
+ model="gpt-4o",
73
+ instructions="Display the current output to the user, then ask for verification.",
74
+ name="user-verification-agent",
75
+ tools=[user_verification],
76
+ )
77
+
78
+ mock_agent.assert_any_call(
79
+ model="gpt-4o",
80
+ instructions="Find relevant information about the provided task by combining web searches with visiting webpages.",
81
+ name="search-web-agent",
82
+ tools=[search_web, visit_webpage],
83
+ )
84
+
85
+ mock_agent.assert_any_call(
86
+ model="gpt-4o",
87
+ instructions=None,
88
+ name="communication-agent",
89
+ tools=[final_answer],
90
+ )
91
+
92
+ mock_agent.assert_any_call(
93
+ model="gpt-4o",
94
+ instructions=DEFAULT_MULTIAGENT_INSTRUCTIONS,
95
+ name="surf-spot-finder",
96
+ # TODO: add more elaborate checks
97
+ handoffs=ANY,
98
+ tools=ANY,
99
+ )
tests/unit/test_unit_tracing.py CHANGED
@@ -32,6 +32,20 @@ def test_get_tracer_provider(tmp_path, json_tracer):
32
  )
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def test_invalid_agent_type():
36
  with pytest.raises(ValueError, match="agent_type must be one of"):
37
  setup_tracing(MagicMock(), "invalid_agent_type")
 
32
  )
33
 
34
 
35
+ @pytest.mark.parametrize(
36
+ "agent_type,instrumentor",
37
+ [
38
+ ("openai", "openai.OpenAIInstrumentor"),
39
+ ("openai_multi_agent", "openai.OpenAIInstrumentor"),
40
+ ("smolagents", "smolagents.SmolagentsInstrumentor"),
41
+ ],
42
+ )
43
+ def test_setup_tracing(agent_type, instrumentor):
44
+ with patch(f"openinference.instrumentation.{instrumentor}") as mock_instrumentor:
45
+ setup_tracing(MagicMock(), agent_type)
46
+ mock_instrumentor.assert_called_once()
47
+
48
+
49
  def test_invalid_agent_type():
50
  with pytest.raises(ValueError, match="agent_type must be one of"):
51
  setup_tracing(MagicMock(), "invalid_agent_type")