Nathan Brake committed
Commit 7c69831 · unverified · 1 Parent(s): 9d27e7a

Uniform Trace extraction and inferring agent_type (#27)


* Need to re-tool the eval files but at least now evidence extraction is possible

* Linting

* linting

* fix unit tests

* lint

* patch langchain and openai output format

* lint

* fix pyproj

pyproject.toml CHANGED
@@ -20,6 +20,7 @@ dependencies = [
 langchain = [
     "langchain",
     "langgraph",
+    "langchain-openai>=0.3.9",
     "openinference-instrumentation-langchain"
 ]
 smolagents = [
@@ -29,7 +30,7 @@ smolagents = [
 
 openai = [
     "openai-agents",
-    "openinference-instrumentation-openai-agents"
+    "openinference-instrumentation-openai-agents>=0.1.2"
 ]
 
 mcp = [
src/surf_spot_finder/agents/__init__.py CHANGED
@@ -1,15 +1,30 @@
+from enum import Enum
 from .langchain import run_lanchain_agent
 from .openai import run_openai_agent, run_openai_multi_agent
 from .smolagents import run_smolagent
 
+
+# Define the available agent type enums
+class AgentType(str, Enum):
+    LANGCHAIN = "langchain"
+    OPENAI = "openai"
+    OPENAI_MULTI_AGENT = "openai_multi_agent"
+    SMOLAGENTS = "smolagents"
+
+
 RUNNERS = {
-    "langchain": run_lanchain_agent,
-    "openai": run_openai_agent,
-    "smolagents": run_smolagent,
-    "openai_multi_agent": run_openai_multi_agent,
+    AgentType.LANGCHAIN: run_lanchain_agent,
+    AgentType.OPENAI: run_openai_agent,
+    AgentType.SMOLAGENTS: run_smolagent,
+    AgentType.OPENAI_MULTI_AGENT: run_openai_multi_agent,
 }
 
 
-def validate_agent_type(value) -> str:
-    if value not in RUNNERS:
-        raise ValueError(f"agent_type must be one of {RUNNERS.keys()}")
+def validate_agent_type(value: str) -> str:
+    try:
+        agent_type = AgentType(value)
+        if agent_type not in RUNNERS:
+            raise ValueError(f"agent_type {value} is valid but has no runner")
+        return value
+    except ValueError:
+        raise ValueError(f"agent_type must be one of {[e.value for e in AgentType]}")
src/surf_spot_finder/agents/langchain.py CHANGED
@@ -14,6 +14,8 @@ try:
 except ImportError:
     langchain_available = False
 
+DEFAULT_RECURSION_LIMIT = 50
+
 
 @logger.catch(reraise=True)
 def run_lanchain_agent(
@@ -52,13 +54,14 @@ def run_lanchain_agent(
 
     model = init_chat_model(model_id)
     agent = create_react_agent(
-        model=model,
-        tools=imported_tools,
-        checkpointer=MemorySaver(),
+        model=model, tools=imported_tools, checkpointer=MemorySaver()
     )
     for step in agent.stream(
         {"messages": [HumanMessage(content=prompt)]},
-        {"configurable": {"thread_id": "abc123"}},
+        {
+            "configurable": {"thread_id": "abc123"},
+            "recursion_limit": DEFAULT_RECURSION_LIMIT,
+        },
         stream_mode="values",
     ):
         step["messages"][-1].pretty_print()
src/surf_spot_finder/agents/openai.py CHANGED
@@ -24,6 +24,8 @@ try:
 except ImportError:
     agents_available = None
 
+DEFAULT_MAX_TURNS = 20
+
 
 @logger.catch(reraise=True)
 def run_openai_agent(
@@ -34,6 +36,7 @@ def run_openai_agent(
     api_key_var: Optional[str] = None,
     api_base: Optional[str] = None,
     tools: Optional[list[str]] = None,
+    max_turns: Optional[int] = DEFAULT_MAX_TURNS,
 ) -> RunResult:
     """Runs an OpenAI agent with the given prompt and configuration.
 
@@ -94,7 +97,7 @@ def run_openai_agent(
         name=name,
         tools=imported_tools,
     )
-    result = Runner.run_sync(agent, prompt)
+    result = Runner.run_sync(starting_agent=agent, input=prompt, max_turns=max_turns)
     logger.info(result.final_output)
     return result
 
@@ -105,6 +108,7 @@ def run_openai_multi_agent(
     prompt: str,
     name: str = "surf-spot-finder",
     instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
+    max_turns: Optional[int] = DEFAULT_MAX_TURNS,
     **kwargs,
 ) -> RunResult:
     """Runs multiple OpenAI agents orchestrated by a main agent.
@@ -176,6 +180,8 @@ def run_openai_multi_agent(
         ],
     )
 
-    result = Runner.run_sync(main_agent, prompt)
+    result = Runner.run_sync(
+        starting_agent=main_agent, input=prompt, max_turns=max_turns
+    )
     logger.info(result.final_output)
     return result
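
The switch to keyword arguments makes the `Runner.run_sync` calls self-documenting, and `max_turns` bounds the agent loop rather than letting it run indefinitely. In the openai-agents SDK, exceeding the cap raises an exception, so a caller that wants a soft failure has to catch it; a hedged sketch (the agent definition is illustrative, and the `MaxTurnsExceeded` import path is my reading of the SDK, treat it as an assumption):

```python
# Sketch: bounding the agent loop with max_turns, assuming the openai-agents
# package. MaxTurnsExceeded is the exception the SDK raises when the cap is
# hit (exact import path assumed).
from agents import Agent, Runner
from agents.exceptions import MaxTurnsExceeded

agent = Agent(name="surf-spot-finder", instructions="Answer briefly.")

try:
    result = Runner.run_sync(
        starting_agent=agent, input="Best surf near Vigo?", max_turns=20
    )
    print(result.final_output)
except MaxTurnsExceeded:
    print("Agent did not converge within 20 turns")
```
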
src/surf_spot_finder/cli.py CHANGED
@@ -26,7 +26,7 @@ def find_surf_spot(
     api_base: Optional[str] = None,
     tools: Optional[list[dict]] = None,
     from_config: Optional[str] = None,
-):
+) -> str:
     """Find the best surf spot based on the given criteria.
 
     Args:
@@ -71,8 +71,10 @@ def find_surf_spot(
     )
 
     logger.info("Setting up tracing")
-    tracer_provider, _ = get_tracer_provider(
-        project_name="surf-spot-finder", json_tracer=config.json_tracer
+    tracer_provider, tracing_path = get_tracer_provider(
+        project_name="surf-spot-finder",
+        json_tracer=config.json_tracer,
+        agent_type=config.agent_type,
     )
     setup_tracing(tracer_provider, config.agent_type)
 
@@ -88,6 +90,7 @@ def find_surf_spot(
         api_key_var=config.api_key_var,
         tools=config.tools,
     )
+    return tracing_path
 
 
 def main():
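
Returning `tracing_path` is what lets the evaluation harness below reuse the CLI entry point instead of re-implementing the agent run: the caller runs the agent and immediately knows where the telemetry JSON landed. A hedged sketch of the new call contract (argument values illustrative; remaining optional kwargs omitted):

```python
# find_surf_spot() runs the configured agent and returns the path of the
# JSON telemetry file written by the tracer (names from this PR).
from surf_spot_finder.cli import find_surf_spot

telemetry_path = find_surf_spot(
    location="Vigo",
    date="2025-03-27 22:00",
    max_driving_hours=3,
    agent_type="langchain",
    model_id="o1",
    api_key_var="OPENAI_API_KEY",
    json_tracer=True,
)
print(f"Telemetry written to {telemetry_path}")
```
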
src/surf_spot_finder/evaluation/evaluate.py CHANGED
@@ -4,17 +4,17 @@ from textwrap import dedent
 from typing import Any, Dict, List, Optional
 from loguru import logger
 from fire import Fire
-from surf_spot_finder.agents.smolagents import run_smolagent
+from surf_spot_finder.cli import find_surf_spot
 from surf_spot_finder.config import (
-    DEFAULT_PROMPT,
     Config,
 )
-from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
+from surf_spot_finder.prompts.shared import INPUT_PROMPT
 from surf_spot_finder.evaluation.utils import (
-    extract_hypothesis_answer,
+    determine_agent_type,
     verify_checkpoints,
     verify_hypothesis_answer,
 )
+from surf_spot_finder.evaluation.telemetry_utils import extract_hypothesis_answer
 from surf_spot_finder.evaluation.test_case import TestCase
 
 logger.remove()
@@ -31,31 +31,23 @@ def run_agent(test_case: TestCase) -> str:
         max_driving_hours=input_data.max_driving_hours,
         model_id=input_data.model_id,
         api_key_var=input_data.api_key_var,
-        prompt=DEFAULT_PROMPT,
+        prompt=INPUT_PROMPT,
         json_tracer=input_data.json_tracer,
         api_base=input_data.api_base,
         agent_type=input_data.agent_type,
+        tools=input_data.tools,
     )
-    # project_name is a name + uuid
-    project_name = "surf-spot-finder"
-
-    logger.info("Setting up tracing")
-    tracer_provider, telemetry_path = get_tracer_provider(
-        project_name=project_name, json_tracer=config.json_tracer
-    )
-    setup_tracing(tracer_provider, agent_type=config.agent_type)
-    logger.info("Running agent")
-    run_smolagent(
+    return find_surf_spot(
+        location=config.location,
+        date=config.date,
+        max_driving_hours=config.max_driving_hours,
+        agent_type=config.agent_type,
         model_id=config.model_id,
        api_key_var=config.api_key_var,
+        json_tracer=config.json_tracer,
         api_base=config.api_base,
-        prompt=config.prompt.format(
-            LOCATION=config.location,
-            MAX_DRIVING_HOURS=config.max_driving_hours,
-            DATE=config.date,
-        ),
+        tools=config.tools,
     )
-    return telemetry_path
 
 
 def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
@@ -64,8 +56,12 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         telemetry: List[Dict[str, Any]] = json.loads(f.read())
     logger.info(f"Telemetry loaded from {telemetry_path}")
 
+    agent_type = determine_agent_type(telemetry)
+
     # Extract the final answer from the telemetry
-    hypothesis_answer = extract_hypothesis_answer(telemetry)
+    hypothesis_answer = extract_hypothesis_answer(
+        trace=telemetry, agent_type=agent_type
+    )
     logger.info(
         f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
     )
@@ -75,6 +71,7 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         telemetry=telemetry,
         checkpoints=test_case.checkpoints,
         model=llm_judge,
+        agent_type=agent_type,
     )
 
     hypothesis_answer_results = verify_hypothesis_answer(
@@ -110,12 +107,8 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         logger.error(message)
     else:
         logger.info("<green>All checkpoints passed!</green>")
-    logger.info(
-        f"<green>Passed checkpoints: {len(passed_checks)}/{len(verification_results)}</green>"
-    )
-    logger.info(
-        f"<red>Failed checkpoints: {len(failed_checks)}/{len(verification_results)}</red>"
-    )
+    logger.info(f"<green>Passed checkpoints: {len(passed_checks)}</green>")
+    logger.info(f"<red>Failed checkpoints: {len(failed_checks)}</red>")
     logger.info("<green>=====================================</green>")
     logger.info(f"<green>Score: {won_points}/{won_points + missed_points}</green>")
     logger.info("<green>=====================================</green>")
src/surf_spot_finder/evaluation/telemetry_utils.py ADDED
@@ -0,0 +1,301 @@
+from typing import Any, Dict, List
+import json
+from langchain_core.messages import BaseMessage
+import re
+
+from surf_spot_finder.agents import AgentType
+
+
+def extract_hypothesis_answer(
+    trace: List[Dict[str, Any]], agent_type: AgentType
+) -> str:
+    """Extract the hypothesis agent final answer from the trace"""
+    for span in reversed(trace):
+        if agent_type == AgentType.LANGCHAIN:
+            if span["attributes"]["openinference.span.kind"] == "AGENT":
+                content = span["attributes"]["output.value"]
+                # If it's langchain, the actual content is a serialized langchain message that we need to extract.
+                message = json.loads(content)["messages"][0]
+                message = parse_generic_key_value_string(message)
+                base_message = BaseMessage(**message, type="AGENT")
+                print(base_message.text())
+                return base_message.text()
+        elif agent_type == AgentType.SMOLAGENTS:
+            if span["attributes"]["openinference.span.kind"] == "AGENT":
+                content = span["attributes"]["output.value"]
+                # For smolagents, the output value is already plain text.
+                return content
+        elif agent_type == AgentType.OPENAI:
+            # Looking for the final response that has the summary answer
+            if (
+                "attributes" in span
+                and span.get("attributes", {}).get("openinference.span.kind") == "LLM"
+            ):
+                output_key = (
+                    "llm.output_messages.0.message.contents.0.message_content.text"
+                )
+                if output_key in span["attributes"]:
+                    return span["attributes"][output_key]
+        else:
+            raise ValueError(f"Unsupported agent type {agent_type}")
+    raise ValueError("No agent final answer found in trace")
+
+
+def parse_generic_key_value_string(text):
+    """
+    Parse a string that has items of a dict with key-value pairs separated by '='.
+    Only splits on '=' signs, handling quoted strings properly.
+    I think this is to compensate for a bug in openinference? https://github.com/Arize-ai/openinference/issues/1401
+    """
+
+    # Pattern to match key=value pairs, handling quoted values
+    # This regex looks for word characters followed by = and then captures everything
+    # until it finds another word character followed by = or the end of the string
+    # Claude helped me with this one, regex is hard
+    pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
+
+    result = {}
+
+    matches = re.findall(pattern, text)
+    for key, value in matches:
+        # Clean up the key
+        key = key.strip()
+
+        # Clean up the value - remove surrounding quotes if present
+        if (value.startswith("'") and value.endswith("'")) or (
+            value.startswith('"') and value.endswith('"')
+        ):
+            value = value[1:-1]
+
+        # Store in result dictionary
+        result[key] = value
+
+    return result
+
+
+def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
+    """Extract relevant telemetry evidence based on the agent type."""
+    # Data extraction function for each agent type
+    extractors = {
+        AgentType.SMOLAGENTS: _extract_smolagents_data,
+        AgentType.LANGCHAIN: _extract_langchain_data,
+        AgentType.OPENAI: _extract_openai_data,
+    }
+
+    if agent_type not in extractors:
+        raise ValueError(f"Unsupported agent type {agent_type}")
+
+    # Extract raw data from telemetry
+    calls = extractors[agent_type](telemetry)
+
+    # Format data into a consistent structure
+    return _format_evidence(calls, agent_type)
+
+
+def _extract_smolagents_data(telemetry: List[Dict[str, Any]]) -> List[Dict]:
+    """Extract LLM calls and tool calls from smolagents telemetry."""
+    calls = []
+
+    for span in telemetry:
+        # Skip spans without attributes
+        if "attributes" not in span:
+            continue
+
+        attributes = span["attributes"]
+
+        # Extract tool information
+        if "tool.name" in attributes or span.get("name", "").startswith("SimpleTool"):
+            tool_info = {
+                "tool_name": attributes.get(
+                    "tool.name", span.get("name", "Unknown tool")
+                ),
+                "status": "success"
+                if span.get("status", {}).get("status_code") == "OK"
+                else "error",
+                "error": span.get("status", {}).get("description", None),
+            }
+
+            # Extract input if available
+            if "input.value" in attributes:
+                try:
+                    input_value = json.loads(attributes["input.value"])
+                    if "kwargs" in input_value:
+                        # For smolagents, the actual input is often in the kwargs field
+                        tool_info["input"] = input_value["kwargs"]
+                    else:
+                        tool_info["input"] = input_value
+                except (json.JSONDecodeError, TypeError):
+                    tool_info["input"] = attributes["input.value"]
+
+            # Extract output if available
+            if "output.value" in attributes:
+                try:
+                    # Try to parse JSON output
+                    output_value = (
+                        json.loads(attributes["output.value"])
+                        if isinstance(attributes["output.value"], str)
+                        else attributes["output.value"]
+                    )
+                    tool_info["output"] = output_value
+                except (json.JSONDecodeError, TypeError):
+                    tool_info["output"] = attributes["output.value"]
+            else:
+                tool_info["output"] = "No output found"
+
+            calls.append(tool_info)
+
+        # Extract LLM calls to see reasoning
+        elif "LiteLLMModel.__call__" in span.get("name", ""):
+            # The LLM output may be in different places depending on the implementation
+            output_content = None
+
+            # Try to get the output from the llm.output_messages.0.message.content attribute
+            if "llm.output_messages.0.message.content" in attributes:
+                output_content = attributes["llm.output_messages.0.message.content"]
+
+            # Or try to parse it from the output.value as JSON
+            elif "output.value" in attributes:
+                try:
+                    output_value = json.loads(attributes["output.value"])
+                    if "content" in output_value:
+                        output_content = output_value["content"]
+                except (json.JSONDecodeError, TypeError):
+                    pass
+
+            if output_content:
+                calls.append(
+                    {
+                        "model": attributes.get("llm.model_name", "Unknown model"),
+                        "output": output_content,
+                        "type": "reasoning",
+                    }
+                )
+
+    return calls
+
+
+def _extract_langchain_data(telemetry: List[Dict[str, Any]]) -> List:
+    """Extract LLM calls and tool calls from LangChain telemetry."""
+    calls = []
+
+    for span in telemetry:
+        if "attributes" not in span:
+            continue
+
+        attributes = span.get("attributes", {})
+        span_kind = attributes.get("openinference.span.kind", "")
+
+        # Collect LLM calls
+        if span_kind == "LLM" and "llm.output_messages.0.message.content" in attributes:
+            llm_info = {
+                "model": attributes.get("llm.model_name", "Unknown model"),
+                "input": attributes.get("llm.input_messages.0.message.content", ""),
+                "output": attributes.get("llm.output_messages.0.message.content", ""),
+                "type": "reasoning",
+            }
+            calls.append(llm_info)
+
+        # Try to find tool calls
+        if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
+            tool_info = {
+                "tool_name": attributes.get(
+                    "tool.name", span.get("name", "Unknown tool")
+                ),
+                "status": "success"
+                if span.get("status", {}).get("status_code") == "OK"
+                else "error",
+                "error": span.get("status", {}).get("description", None),
+            }
+
+            if "input.value" in attributes:
+                try:
+                    input_value = json.loads(attributes["input.value"])
+                    tool_info["input"] = input_value
+                except Exception:
+                    tool_info["input"] = attributes["input.value"]
+
+            if "output.value" in attributes:
+                tool_info["output"] = parse_generic_key_value_string(
+                    json.loads(attributes["output.value"])["output"]
+                )["content"]
+
+            calls.append(tool_info)
+
+    return calls
+
+
+def _extract_openai_data(telemetry: List[Dict[str, Any]]) -> list:
+    """Extract LLM calls and tool calls from OpenAI telemetry."""
+    calls = []
+
+    for span in telemetry:
+        if "attributes" not in span:
+            continue
+
+        attributes = span.get("attributes", {})
+        span_kind = attributes.get("openinference.span.kind", "")
+
+        # Collect LLM interactions - look for direct message content first
+        if span_kind == "LLM":
+            # Initialize the LLM info dictionary
+            span_info = {}
+
+            # Try to get input message
+            input_key = "llm.input_messages.1.message.content"  # User message is usually at index 1
+            if input_key in attributes:
+                span_info["input"] = attributes[input_key]
+
+            # Try to get output message directly
+            output_content = None
+            # Try in multiple possible locations
+            for key in [
+                "llm.output_messages.0.message.content",
+                "llm.output_messages.0.message.contents.0.message_content.text",
+            ]:
+                if key in attributes:
+                    output_content = attributes[key]
+                    break
+
+            # If we found direct output content, use it
+            if output_content:
+                span_info["output"] = output_content
+                calls.append(span_info)
+        elif span_kind == "TOOL":
+            tool_name = attributes.get("tool.name", "Unknown tool")
+            tool_output = attributes.get("output.value", "")
+
+            span_info = {
+                "tool_name": tool_name,
+                "input": attributes.get("input.value", ""),
+                "output": tool_output,
+                "status": span.get("status", {}).get("status_code"),
+            }
+            span_info["input"] = json.loads(span_info["input"])
+
+            calls.append(span_info)
+
+    return calls
+
+
+def _format_evidence(calls: List[Dict], agent_type: AgentType) -> str:
+    """Format extracted data into a standardized output format."""
+    evidence = f"## {agent_type.name} Agent Execution\n\n"
+
+    for idx, call in enumerate(calls, start=1):
+        evidence += f"### Call {idx}\n"
+
+        # Truncate any values that are too long
+        max_length = 400
+        call = {
+            k: (
+                v[:max_length] + "..."
+                if isinstance(v, str) and len(v) > max_length
+                else v
+            )
+            for k, v in call.items()
+        }
+
+        # Use ensure_ascii=False to prevent escaping Unicode characters
+        evidence += json.dumps(call, indent=2, ensure_ascii=False) + "\n\n"
+
+    return evidence
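
To see what `parse_generic_key_value_string` compensates for: per the linked openinference issue, the LangChain instrumentation can serialize a message as a repr-style `key=value` string rather than JSON. A quick check against such a string (the sample payload is hypothetical but matches that shape):

```python
from surf_spot_finder.evaluation.telemetry_utils import parse_generic_key_value_string

# Repr-style payload of the kind the instrumentation emits for a message.
sample = "content='Playa de Samil looks best' additional_kwargs={} type='ai'"
parsed = parse_generic_key_value_string(sample)
assert parsed["content"] == "Playa de Samil looks best"
assert parsed["type"] == "ai"
print(parsed)
```
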
src/surf_spot_finder/evaluation/test_case.py CHANGED
@@ -15,6 +15,7 @@ class InputModel(BaseModel):
     json_tracer: bool
     api_base: Optional[str] = None
     agent_type: str
+    tools: Optional[List[str]] = None
 
 
 class CheckpointCriteria(BaseModel):
src/surf_spot_finder/evaluation/test_cases/alpha.yaml CHANGED
@@ -1,30 +1,65 @@
 # Test case for surf spot finder
+
+# You only need this input data if you want to run the test case; if you pass
+# in a path to a telemetry file, this is ignored
 input:
   location: "Vigo"
-  date: "2025-03-15 22:00"
+  date: "2025-03-27 22:00"
   max_driving_hours: 3
-  model_id: "openai/o3-mini"
   api_key_var: "OPENAI_API_KEY"
   json_tracer: true
   api_base: null
-  agent_type: "smolagents"
-
+  # model_id: "openai/o1"
+  # agent_type: "smolagents"
+  # tools:
+  #   - "surf_spot_finder.tools.driving_hours_to_meters"
+  #   - "surf_spot_finder.tools.get_area_lat_lon"
+  #   - "surf_spot_finder.tools.get_surfing_spots"
+  #   - "surf_spot_finder.tools.get_wave_forecast"
+  #   - "surf_spot_finder.tools.get_wind_forecast"
+  #   - "surf_spot_finder.tools.search_web"
+  #   - "surf_spot_finder.tools.visit_webpage"
+  #   - "smolagents.PythonInterpreterTool"
+  #   - "smolagents.FinalAnswerTool"
+  agent_type: langchain
+  model_id: o1
+  tools:
+    - "surf_spot_finder.tools.driving_hours_to_meters"
+    - "surf_spot_finder.tools.get_area_lat_lon"
+    - "surf_spot_finder.tools.get_surfing_spots"
+    - "surf_spot_finder.tools.get_wave_forecast"
+    - "surf_spot_finder.tools.get_wind_forecast"
+    - "surf_spot_finder.tools.search_web"
+    - "surf_spot_finder.tools.visit_webpage"
+  # model_id: o3-mini
+  # agent_type: openai
+  # tools:
+  #   - "surf_spot_finder.tools.driving_hours_to_meters"
+  #   - "surf_spot_finder.tools.get_area_lat_lon"
+  #   - "surf_spot_finder.tools.get_surfing_spots"
+  #   - "surf_spot_finder.tools.get_wave_forecast"
+  #   - "surf_spot_finder.tools.get_wind_forecast"
+  #   - "surf_spot_finder.tools.search_web"
+  #   - "surf_spot_finder.tools.show_plan"
+  #   - "surf_spot_finder.tools.visit_webpage"
 ground_truth:
   - name: "Surf location"
     points: 5
-    value: "Playa de Patos"
-  - name: "Water temperature"
-    points: 1
-    value: "about 14°C +-5°C"
-  - name: "Wave height"
-    points: 1
-    value: "about 1 meter"
+    value: "Playa de Samil"
 
 # Base checkpoints for agent behavior
 # These evaluators for these checkpoints
 # will not consider the hypothesis answer or final answer in their decision making
 checkpoints:
   - points: 1
-    criteria: "Check if the agent consulted DuckDuckGoSearchTool for locations near Vigo."
+    criteria: "Check if the agent did a web search for nearby surf locations."
+  - points: 1
+    criteria: "Check if the agent used the get_surfing_spots tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_wave_forecast tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_wind_forecast tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_area_lat_lon tool and it succeeded"
   - points: 1
-    criteria: "Check if the agent fetched a website for forecasting, not relying on text from a DuckDuckGo search."
+    criteria: "Check if the final answer contains any description about the weather at the chosen location"
src/surf_spot_finder/evaluation/utils.py CHANGED
@@ -4,10 +4,36 @@ import re
 
 from litellm import completion
 from textwrap import dedent
+from loguru import logger
 
 from pydantic import BaseModel, ConfigDict
+from surf_spot_finder.evaluation.telemetry_utils import extract_evidence
 from surf_spot_finder.evaluation.test_case import CheckpointCriteria
 
+from surf_spot_finder.agents import AgentType
+
+
+def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
+    """Determine the agent type based on the trace.
+    These checks are not really stable ways to find it; we're waiting on a
+    reliable method for determining the agent type. This is a temporary solution.
+    """
+    for span in trace:
+        if "langchain" in span.get("attributes", {}).get("input.value", ""):
+            logger.info("Agent type is LANGCHAIN")
+            return AgentType.LANGCHAIN
+        if span.get("attributes", {}).get("smolagents.max_steps"):
+            logger.info("Agent type is SMOLAGENTS")
+            return AgentType.SMOLAGENTS
+        # This is extremely fragile but there currently isn't
+        # any specific key to indicate the agent type
+        if span.get("name") == "response":
+            logger.info("Agent type is OPENAI")
+            return AgentType.OPENAI
+    raise ValueError(
+        "Could not determine agent type from trace, or agent type not supported"
+    )
+
 
 class EvaluationResult(BaseModel):
     """Represents the result of evaluating a criterion"""
@@ -19,15 +45,6 @@ class EvaluationResult(BaseModel):
     points: int
 
 
-def extract_hypothesis_answer(telemetry: List[Dict[str, Any]]) -> str | None:
-    """Extract the hypothesis agent final answer from the telemetry data"""
-    for span in reversed(telemetry):
-        if span.get("attributes", {}).get("openinference.span.kind") == "AGENT":
-            hypo = span.get("attributes", {}).get("output.value")
-            return hypo
-    raise ValueError("Final answer not found in telemetry")
-
-
 def evaluate_criterion(
     criteria: str,
     model: str,
@@ -109,6 +126,7 @@ def verify_checkpoints(
     telemetry: List[Dict[str, Any]],
     checkpoints: List[CheckpointCriteria],
     model: str,
+    agent_type: AgentType,
 ) -> List[EvaluationResult]:
     """Verify each checkpoint against the telemetry data using LLM
     These checkpoints do not take the ground truth or hypothesis
@@ -117,9 +135,10 @@ def verify_checkpoints(
     """
     results = []
 
+    evidence = extract_evidence(telemetry, agent_type)
+    print(evidence)
     for checkpoint in checkpoints:
         criteria = checkpoint.criteria
-        evidence = extract_relevant_evidence(telemetry, criteria)
 
         evaluation = evaluate_criterion(
             criteria=criteria,
@@ -156,62 +175,3 @@ def verify_hypothesis_answer(
         results.append(evaluation)
 
     return results
-
-
-def extract_relevant_evidence(telemetry: List[Dict[str, Any]], criteria: str) -> str:
-    """Extract relevant telemetry evidence based on the checkpoint criteria
-    TODO this is not a very robust implementation, since it requires knowledge about which tools have been
-    implemented. We should abstract this so that it can dynamically figure out what tools may have been used
-    and check for them appropriately. I understand that this tool should probably have some better way of abstracting
-    relevant information from the opentelemetry spans."""
-    evidence = ""
-
-    # Look for evidence of tool usage
-    if "DuckDuckGoSearchTool" in criteria:
-        search_spans = [
-            span for span in telemetry if span.get("name") == "DuckDuckGoSearchTool"
-        ]
-        evidence += f"Search tool was used {len(search_spans)} times.\n"
-        for i, span in enumerate(search_spans):  # Limit to first 3 searches
-            if "attributes" in span and "input.value" in span["attributes"]:
-                try:
-                    input_value = json.loads(span["attributes"]["input.value"])
-                    if "kwargs" in input_value and "query" in input_value["kwargs"]:
-                        evidence += (
-                            f"Search query {i + 1}: {input_value['kwargs']['query']}\n"
-                        )
-                except (json.JSONDecodeError, TypeError):
-                    pass
-
-    # Look for evidence of website fetching
-    if "fetched a website" in criteria:
-        fetch_spans = [
-            span
-            for span in telemetry
-            if span.get("attributes", {}).get("tool.name") == "fetch"
-        ]
-        evidence += f"Website fetch tool was used {len(fetch_spans)} times.\n"
-        for i, span in enumerate(fetch_spans):  # Limit to first 3 fetches
-            if "attributes" in span and "input.value" in span["attributes"]:
-                try:
-                    input_value = json.loads(span["attributes"]["input.value"])
-                    if "kwargs" in input_value and "url" in input_value["kwargs"]:
-                        evidence += (
-                            f"Fetched URL {i + 1}: {input_value['kwargs']['url']}\n"
-                        )
-                except (json.JSONDecodeError, TypeError):
-                    pass
-
-    # Add general evidence about all tool calls
-    tool_calls = {}
-    for span in telemetry:
-        if "name" in span and span["name"] not in tool_calls:
-            tool_calls[span["name"]] = 1
-        elif "name" in span:
-            tool_calls[span["name"]] += 1
-
-    evidence += "\nTool calls summary:\n"
-    for tool, count in tool_calls.items():
-        evidence += f"- {tool}: {count} call(s)\n"
-
-    return evidence
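
The heuristics in `determine_agent_type` key off framework-specific span fields: smolagents sets a `smolagents.max_steps` attribute, LangChain traces carry the string "langchain" inside serialized `input.value` payloads, and the OpenAI Agents instrumentation currently emits a span literally named `response`. A toy trace per framework, reduced to just the field each branch reads (fake spans, for illustration only):

```python
from surf_spot_finder.agents import AgentType
from surf_spot_finder.evaluation.utils import determine_agent_type

# Minimal fake spans showing which field trips each heuristic branch.
smolagents_trace = [{"attributes": {"smolagents.max_steps": 20}}]
langchain_trace = [{"attributes": {"input.value": '{"id": ["langchain", "schema"]}'}}]
openai_trace = [{"name": "response", "attributes": {}}]

assert determine_agent_type(smolagents_trace) == AgentType.SMOLAGENTS
assert determine_agent_type(langchain_trace) == AgentType.LANGCHAIN
assert determine_agent_type(openai_trace) == AgentType.OPENAI
```
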
src/surf_spot_finder/tracing.py CHANGED
@@ -1,11 +1,10 @@
 import os
 import json
 from datetime import datetime
-
-from opentelemetry import trace
 from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-from opentelemetry.sdk.trace.export import SpanExporter
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter
+
+from surf_spot_finder.agents import AgentType
 
 
 class JsonFileSpanExporter(SpanExporter):
@@ -44,7 +43,10 @@
 
 
 def get_tracer_provider(
-    project_name: str, json_tracer: bool, output_dir: str = "telemetry_output"
+    project_name: str,
+    json_tracer: bool,
+    agent_type: AgentType,
+    output_dir: str = "telemetry_output",
 ) -> tuple[TracerProvider, str | None]:
     """
     Create a tracer_provider based on the selected mode.
@@ -52,6 +54,7 @@ def get_tracer_provider(
     Args:
         project_name: Name of the project for tracing
        json_tracer: Whether to use the custom JSON file exporter (True) or Phoenix (False)
+        agent_type: The type of agent being used.
        output_dir: The directory where the telemetry output will be stored.
            Only used if `json_tracer=True`.
            Defaults to "telemetry_output".
@@ -66,9 +69,7 @@ def get_tracer_provider(
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
 
        tracer_provider = TracerProvider()
-       trace.set_tracer_provider(tracer_provider)
-
-       file_name = f"{output_dir}/{project_name}-{timestamp}.json"
+       file_name = f"{output_dir}/{agent_type}-{project_name}-{timestamp}.json"
        json_file_exporter = JsonFileSpanExporter(file_name=file_name)
        span_processor = SimpleSpanProcessor(json_file_exporter)
        tracer_provider.add_span_processor(span_processor)
@@ -97,14 +98,17 @@ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
     validate_agent_type(agent_type)
 
     if "openai" in agent_type:
-        from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor
-
-        OpenAIAgentsInstrumentor().instrument(tracer_provider=tracer_provider)
+        from openinference.instrumentation.openai_agents import (
+            OpenAIAgentsInstrumentor as Instrumentor,
+        )
     elif agent_type == "smolagents":
-        from openinference.instrumentation.smolagents import SmolagentsInstrumentor
-
-        SmolagentsInstrumentor().instrument(tracer_provider=tracer_provider)
+        from openinference.instrumentation.smolagents import (
+            SmolagentsInstrumentor as Instrumentor,
+        )
     elif agent_type == "langchain":
-        from openinference.instrumentation.langchain import LangChainInstrumentor
-
-        LangChainInstrumentor().instrument(tracer_provider=tracer_provider)
+        from openinference.instrumentation.langchain import (
+            LangChainInstrumentor as Instrumentor,
+        )
+    else:
+        raise ValueError(f"Unsupported agent type: {agent_type}")
+    Instrumentor().instrument(tracer_provider=tracer_provider)
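
Taken together, the tracing changes mean the caller now owns the provider (the global `trace.set_tracer_provider` call is gone) and the output filename is self-describing, since it embeds the agent type. A hedged usage sketch (values illustrative; the file path follows the `{output_dir}/{agent_type}-{project_name}-{timestamp}.json` template above):

```python
from surf_spot_finder.agents import AgentType
from surf_spot_finder.tracing import get_tracer_provider, setup_tracing

# With json_tracer=True, spans are written to a JSON file whose name encodes
# the agent type, so evaluation can tell runs apart by filename alone.
tracer_provider, tracing_path = get_tracer_provider(
    project_name="surf-spot-finder",
    json_tracer=True,
    agent_type=AgentType.LANGCHAIN,
)
setup_tracing(tracer_provider, AgentType.LANGCHAIN.value)
print(tracing_path)
```
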
tests/unit/test_unit_tracing.py CHANGED
@@ -2,31 +2,27 @@ from unittest.mock import patch, MagicMock
 
 import pytest
 
+from surf_spot_finder.agents import AgentType
 from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
 
 
 @pytest.mark.parametrize("json_tracer", [True, False])
 def test_get_tracer_provider(tmp_path, json_tracer):
-    mock_trace = MagicMock()
     mock_tracer_provider = MagicMock()
     mock_register = MagicMock()
 
     with (
-        patch("surf_spot_finder.tracing.trace", mock_trace),
         patch("surf_spot_finder.tracing.TracerProvider", mock_tracer_provider),
         patch("phoenix.otel.register", mock_register),
     ):
        get_tracer_provider(
            project_name="test_project",
            json_tracer=json_tracer,
+            agent_type=AgentType.SMOLAGENTS,
            output_dir=tmp_path / "telemetry",
        )
    assert (tmp_path / "telemetry").exists() == json_tracer
-    if json_tracer:
-        mock_trace.set_tracer_provider.assert_called_once_with(
-            mock_tracer_provider.return_value
-        )
-    else:
+    if not json_tracer:
        mock_register.assert_called_once_with(
            project_name="test_project", set_global_tracer_provider=True
        )