Nathan Brake committed on
Commit
fea07c2
·
unverified ·
1 Parent(s): ef766f7

Split telemetry processing into cleaner classes, support ollama (#31)

Browse files
.pylintrc ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [MESSAGES CONTROL]
2
+ disable=C0415
examples/langchain_single_agent_vertical.yaml CHANGED
@@ -5,6 +5,7 @@ input:
5
  # input_prompt_template:
6
  agent:
7
  model_id: o3-mini
 
8
  agent_type: langchain
9
  tools:
10
  - "surf_spot_finder.tools.driving_hours_to_meters"
 
5
  # input_prompt_template:
6
  agent:
7
  model_id: o3-mini
8
+ # model_id: ollama/llama3.2:3b
9
  agent_type: langchain
10
  tools:
11
  - "surf_spot_finder.tools.driving_hours_to_meters"
pyproject.toml CHANGED
@@ -21,6 +21,7 @@ langchain = [
21
  "langchain",
22
  "langgraph",
23
  "langchain-openai>=0.3.9",
 
24
  "openinference-instrumentation-langchain"
25
  ]
26
  smolagents = [
 
21
  "langchain",
22
  "langgraph",
23
  "langchain-openai>=0.3.9",
24
+ "langchain-ollama>=0.3.0",
25
  "openinference-instrumentation-langchain"
26
  ]
27
  smolagents = [
src/surf_spot_finder/agents/langchain.py CHANGED
@@ -51,8 +51,11 @@ def run_lanchain_agent(
51
  if not isinstance(imported_tool, BaseTool):
52
  imported_tool = tool(imported_tool)
53
  imported_tools.append((imported_tool))
54
-
55
- model = init_chat_model(model_id)
 
 
 
56
  agent = create_react_agent(
57
  model=model, tools=imported_tools, checkpointer=MemorySaver()
58
  )
 
51
  if not isinstance(imported_tool, BaseTool):
52
  imported_tool = tool(imported_tool)
53
  imported_tools.append((imported_tool))
54
+ if "/" in model_id:
55
+ model_provider, model_id = model_id.split("/")
56
+ model = init_chat_model(model_id, model_provider=model_provider)
57
+ else:
58
+ model = init_chat_model(model_id)
59
  agent = create_react_agent(
60
  model=model, tools=imported_tools, checkpointer=MemorySaver()
61
  )
src/surf_spot_finder/evaluation/evaluate.py CHANGED
@@ -8,13 +8,11 @@ from surf_spot_finder.cli import find_surf_spot
8
  from surf_spot_finder.config import (
9
  Config,
10
  )
11
- from surf_spot_finder.prompts.shared import INPUT_PROMPT
12
  from surf_spot_finder.evaluation.utils import (
13
- determine_agent_type,
14
  verify_checkpoints,
15
  verify_hypothesis_answer,
16
  )
17
- from surf_spot_finder.evaluation.telemetry_utils import extract_hypothesis_answer
18
  from surf_spot_finder.evaluation.test_case import TestCase
19
 
20
  logger.remove()
@@ -22,22 +20,15 @@ logger = logger.opt(ansi=True)
22
  logger.add(sys.stdout, colorize=True, format="{message}")
23
 
24
 
25
- def run_agent(test_case: TestCase) -> str:
26
  input_data = test_case.input
27
- agent_config = test_case.agent
28
  logger.info("Loading config")
29
- config = Config(
30
- location=input_data.location,
31
- date=input_data.date,
32
- max_driving_hours=input_data.max_driving_hours,
33
- model_id=agent_config.model_id,
34
- api_key_var=agent_config.api_key_var,
35
- prompt=INPUT_PROMPT,
36
- json_tracer=input_data.json_tracer,
37
- api_base=agent_config.api_base,
38
- agent_type=agent_config.agent_type,
39
- tools=agent_config.tools,
40
- )
41
  return find_surf_spot(
42
  location=config.location,
43
  date=config.date,
@@ -57,12 +48,11 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
57
  telemetry: List[Dict[str, Any]] = json.loads(f.read())
58
  logger.info(f"Telemetry loaded from {telemetry_path}")
59
 
60
- agent_type = determine_agent_type(telemetry)
61
 
62
  # Extract the final answer from the telemetry
63
- hypothesis_answer = extract_hypothesis_answer(
64
- trace=telemetry, agent_type=agent_type
65
- )
66
  logger.info(
67
  f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
68
  )
@@ -72,7 +62,7 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
72
  telemetry=telemetry,
73
  checkpoints=test_case.checkpoints,
74
  model=llm_judge,
75
- agent_type=agent_type,
76
  )
77
 
78
  hypothesis_answer_results = verify_hypothesis_answer(
@@ -116,7 +106,9 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
116
 
117
 
118
  def evaluate(
119
- test_case_path: str, agent_config_path: str, telemetry_path: Optional[str] = None
 
 
120
  ) -> None:
121
  """
122
  Evaluate agent performance using either a provided telemetry file or by running the agent.
@@ -125,15 +117,16 @@ def evaluate(
125
  telemetry_path: Optional path to an existing telemetry file. If not provided,
126
  the agent will be run to generate one.
127
  """
128
- test_case = TestCase.from_yaml(
129
- test_case_path=test_case_path, agent_config_path=agent_config_path
130
- )
131
 
132
  if telemetry_path is None:
133
  logger.info(
134
  "No telemetry path provided. Running agent to generate telemetry..."
135
  )
136
- telemetry_path = run_agent(test_case)
 
 
 
137
  else:
138
  logger.info(f"Using provided telemetry file: {telemetry_path}")
139
  logger.info(
 
8
  from surf_spot_finder.config import (
9
  Config,
10
  )
11
+ from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
12
  from surf_spot_finder.evaluation.utils import (
 
13
  verify_checkpoints,
14
  verify_hypothesis_answer,
15
  )
 
16
  from surf_spot_finder.evaluation.test_case import TestCase
17
 
18
  logger.remove()
 
20
  logger.add(sys.stdout, colorize=True, format="{message}")
21
 
22
 
23
+ def run_agent(test_case: TestCase, agent_config_path: str) -> str:
24
  input_data = test_case.input
25
+
26
  logger.info("Loading config")
27
+ config = Config.from_yaml(agent_config_path)
28
+ config.location = input_data.location
29
+ config.date = input_data.date
30
+ config.max_driving_hours = input_data.max_driving_hours
31
+ config.json_tracer = input_data.json_tracer
 
 
 
 
 
 
 
32
  return find_surf_spot(
33
  location=config.location,
34
  date=config.date,
 
48
  telemetry: List[Dict[str, Any]] = json.loads(f.read())
49
  logger.info(f"Telemetry loaded from {telemetry_path}")
50
 
51
+ agent_type = TelemetryProcessor.determine_agent_type(telemetry)
52
 
53
  # Extract the final answer from the telemetry
54
+ processor = TelemetryProcessor.create(agent_type)
55
+ hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
 
56
  logger.info(
57
  f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
58
  )
 
62
  telemetry=telemetry,
63
  checkpoints=test_case.checkpoints,
64
  model=llm_judge,
65
+ processor=processor,
66
  )
67
 
68
  hypothesis_answer_results = verify_hypothesis_answer(
 
106
 
107
 
108
  def evaluate(
109
+ test_case_path: str,
110
+ agent_config_path: str = None,
111
+ telemetry_path: Optional[str] = None,
112
  ) -> None:
113
  """
114
  Evaluate agent performance using either a provided telemetry file or by running the agent.
 
117
  telemetry_path: Optional path to an existing telemetry file. If not provided,
118
  the agent will be run to generate one.
119
  """
120
+ test_case = TestCase.from_yaml(test_case_path=test_case_path)
 
 
121
 
122
  if telemetry_path is None:
123
  logger.info(
124
  "No telemetry path provided. Running agent to generate telemetry..."
125
  )
126
+ assert (
127
+ agent_config_path is not None
128
+ ), "Agent config path must be provided if running agent"
129
+ telemetry_path = run_agent(test_case, agent_config_path)
130
  else:
131
  logger.info(f"Using provided telemetry file: {telemetry_path}")
132
  logger.info(
src/surf_spot_finder/evaluation/telemetry/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .telemetry import TelemetryProcessor
2
+
3
+ __all__ = ["TelemetryProcessor"]
src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import json
3
+ from langchain_core.messages import BaseMessage
4
+
5
+ from surf_spot_finder.agents import AgentType
6
+ from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
7
+
8
+
9
class LangchainTelemetryProcessor(TelemetryProcessor):
    """Processor for Langchain agent telemetry data.

    Works on a list of span dicts whose ``attributes`` mapping uses
    ``openinference.*`` / ``llm.*`` / ``tool.*`` keys, and extracts the agent's
    final answer plus the LLM and tool calls used as evaluation evidence.
    """

    def _get_agent_type(self) -> AgentType:
        """Return the agent type handled by this processor."""
        return AgentType.LANGCHAIN

    def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
        """Return the answer text from the last ``AGENT`` span in the trace.

        Args:
            trace: Telemetry spans in recorded order; searched back to front.

        Returns:
            The text of the message parsed out of the span's ``output.value``.

        Raises:
            ValueError: If the trace contains no span of kind ``AGENT``.
        """
        for span in reversed(trace):
            # .get() keeps the search going past spans that have no attributes
            # or no span kind, instead of aborting with a KeyError.
            attributes = span.get("attributes", {})
            if attributes.get("openinference.span.kind") != "AGENT":
                continue

            content = attributes["output.value"]
            # The output is JSON whose "messages" entries are serialized
            # langchain messages in "key=value" form rather than nested JSON.
            # NOTE(review): index 0 selects the first serialized message;
            # confirm this is the final answer and not the initial input.
            message = json.loads(content)["messages"][0]
            message = self.parse_generic_key_value_string(message)
            base_message = BaseMessage(content=message["content"], type="AGENT")
            return base_message.text()

        raise ValueError("No agent final answer found in trace")

    def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
        """Extract LLM calls and tool calls from LangChain telemetry.

        Args:
            telemetry: Telemetry spans in recorded order.

        Returns:
            One dict per LLM call (``model``, ``input``, ``output``, ``type``)
            or tool call (``tool_name``, ``status``, ``error`` and, when
            present on the span, ``input`` / ``output``), in span order.
        """
        calls = []

        for span in telemetry:
            if "attributes" not in span:
                continue

            attributes = span.get("attributes", {})
            span_kind = attributes.get("openinference.span.kind", "")

            # Collect LLM calls
            if (
                span_kind == "LLM"
                and "llm.output_messages.0.message.content" in attributes
            ):
                calls.append(
                    {
                        "model": attributes.get("llm.model_name", "Unknown model"),
                        "input": attributes.get(
                            "llm.input_messages.0.message.content", ""
                        ),
                        "output": attributes.get(
                            "llm.output_messages.0.message.content", ""
                        ),
                        "type": "reasoning",
                    }
                )

            # Try to find tool calls
            if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
                status = span.get("status", {})
                tool_info = {
                    "tool_name": attributes.get(
                        "tool.name", span.get("name", "Unknown tool")
                    ),
                    "status": "success"
                    if status.get("status_code") == "OK"
                    else "error",
                    "error": status.get("description", None),
                }

                if "input.value" in attributes:
                    raw_input = attributes["input.value"]
                    try:
                        tool_info["input"] = json.loads(raw_input)
                    except (json.JSONDecodeError, TypeError):
                        # Not JSON: keep the raw value as evidence.
                        tool_info["input"] = raw_input

                if "output.value" in attributes:
                    raw_output = attributes["output.value"]
                    try:
                        tool_info["output"] = self.parse_generic_key_value_string(
                            json.loads(raw_output)["output"]
                        )["content"]
                    except (json.JSONDecodeError, TypeError, KeyError):
                        # Same best-effort fallback as for the input: one
                        # oddly shaped tool output must not abort evidence
                        # extraction for the whole trace.
                        tool_info["output"] = raw_output

                calls.append(tool_info)

        return calls
src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import json
3
+
4
+ from surf_spot_finder.agents import AgentType
5
+ from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
6
+
7
+
8
class OpenAITelemetryProcessor(TelemetryProcessor):
    """Processor for OpenAI agent telemetry data.

    Works on a list of span dicts whose ``attributes`` mapping uses
    ``openinference.*`` / ``llm.*`` / ``tool.*`` keys.
    """

    def _get_agent_type(self) -> AgentType:
        """Return the agent type handled by this processor."""
        return AgentType.OPENAI

    def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
        """Return the text output of the last ``LLM`` span that carries one.

        Args:
            trace: Telemetry spans in recorded order; searched back to front.

        Returns:
            The value stored under the span's output-text attribute.

        Raises:
            ValueError: If no ``LLM`` span with that attribute exists.
        """
        output_key = "llm.output_messages.0.message.contents.0.message_content.text"

        for span in reversed(trace):
            attributes = span.get("attributes", {})
            # Looking for the final response that has the summary answer
            if (
                attributes.get("openinference.span.kind") == "LLM"
                and output_key in attributes
            ):
                return attributes[output_key]

        raise ValueError("No agent final answer found in trace")

    def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
        """Extract LLM calls and tool calls from OpenAI telemetry.

        Args:
            telemetry: Telemetry spans in recorded order.

        Returns:
            One dict per LLM span that has non-empty output (``output`` and,
            when available, ``input``) and one per ``TOOL`` span
            (``tool_name``, ``input``, ``output``, ``status``), in span order.
        """
        calls = []

        for span in telemetry:
            if "attributes" not in span:
                continue

            attributes = span.get("attributes", {})
            span_kind = attributes.get("openinference.span.kind", "")

            # Collect LLM interactions - look for direct message content first
            if span_kind == "LLM":
                span_info = {}

                # User message is usually at index 1
                input_key = "llm.input_messages.1.message.content"
                if input_key in attributes:
                    span_info["input"] = attributes[input_key]

                # The output text may live under either key; first match wins.
                output_content = None
                for key in (
                    "llm.output_messages.0.message.content",
                    "llm.output_messages.0.message.contents.0.message_content.text",
                ):
                    if key in attributes:
                        output_content = attributes[key]
                        break

                # LLM spans without (non-empty) output are not recorded.
                if output_content:
                    span_info["output"] = output_content
                    calls.append(span_info)
            elif span_kind == "TOOL":
                raw_input = attributes.get("input.value", "")
                try:
                    tool_input = json.loads(raw_input)
                except (json.JSONDecodeError, TypeError):
                    # A missing input defaults to "", which is not valid JSON;
                    # keep the raw value rather than failing the whole trace.
                    tool_input = raw_input

                calls.append(
                    {
                        "tool_name": attributes.get("tool.name", "Unknown tool"),
                        "input": tool_input,
                        "output": attributes.get("output.value", ""),
                        "status": span.get("status", {}).get("status_code"),
                    }
                )

        return calls
80
+
81
+
82
+ # Backward compatibility functions that use the new class structure
83
def extract_hypothesis_answer(
    trace: List[Dict[str, Any]], agent_type: AgentType
) -> str:
    """Return the agent's final answer found in ``trace``.

    Backward-compatible function form: builds the processor for
    ``agent_type`` and delegates to its ``extract_hypothesis_answer``.
    """
    return TelemetryProcessor.create(agent_type).extract_hypothesis_answer(trace)
89
+
90
+
91
def parse_generic_key_value_string(text: str) -> Dict[str, str]:
    """Parse ``key=value`` pairs out of ``text`` into a dict.

    Backward-compatible function form: delegates to
    ``TelemetryProcessor.parse_generic_key_value_string``.
    """
    parsed = TelemetryProcessor.parse_generic_key_value_string(text)
    return parsed
97
+
98
+
99
def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
    """Return formatted telemetry evidence for the given agent type.

    Backward-compatible function form: builds the processor for
    ``agent_type`` and delegates to its ``extract_evidence``.
    """
    return TelemetryProcessor.create(agent_type).extract_evidence(telemetry)
src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import json
3
+
4
+ from surf_spot_finder.agents import AgentType
5
+ from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
6
+
7
+
8
class SmolagentsTelemetryProcessor(TelemetryProcessor):
    """Processor for SmoL Agents telemetry data.

    Works on a list of span dicts whose ``attributes`` mapping uses
    ``openinference.*`` / ``llm.*`` / ``tool.*`` keys.
    """

    def _get_agent_type(self) -> AgentType:
        """Return the agent type handled by this processor."""
        return AgentType.SMOLAGENTS

    def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
        """Return the ``output.value`` of the last ``AGENT`` span in the trace.

        Args:
            trace: Telemetry spans in recorded order; searched back to front.

        Returns:
            The raw ``output.value`` attribute of the matching span.

        Raises:
            ValueError: If the trace contains no span of kind ``AGENT``.
        """
        for span in reversed(trace):
            # .get() keeps the search going past spans that have no attributes
            # or no span kind, instead of aborting with a KeyError.
            attributes = span.get("attributes", {})
            if attributes.get("openinference.span.kind") == "AGENT":
                return attributes["output.value"]

        raise ValueError("No agent final answer found in trace")

    def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
        """Extract LLM calls and tool calls from SmoL Agents telemetry.

        Args:
            telemetry: Telemetry spans in recorded order.

        Returns:
            One dict per tool call (``tool_name``, ``status``, ``error``,
            ``output`` and, when present on the span, ``input``) or LLM call
            with non-empty output (``model``, ``output``, ``type``), in span
            order.
        """
        calls = []

        for span in telemetry:
            # Skip spans without attributes
            if "attributes" not in span:
                continue

            attributes = span["attributes"]

            # Extract tool information
            if "tool.name" in attributes or span.get("name", "").startswith(
                "SimpleTool"
            ):
                status = span.get("status", {})
                tool_info = {
                    "tool_name": attributes.get(
                        "tool.name", span.get("name", "Unknown tool")
                    ),
                    "status": "success"
                    if status.get("status_code") == "OK"
                    else "error",
                    "error": status.get("description", None),
                }

                # Extract input if available
                if "input.value" in attributes:
                    raw_input = attributes["input.value"]
                    try:
                        input_value = json.loads(raw_input)
                        # For SmoLAgents, the actual input is often in the
                        # kwargs field
                        if "kwargs" in input_value:
                            tool_info["input"] = input_value["kwargs"]
                        else:
                            tool_info["input"] = input_value
                    except (json.JSONDecodeError, TypeError):
                        # Not JSON (or not a mapping): keep the raw value.
                        tool_info["input"] = raw_input

                # Extract output if available
                if "output.value" in attributes:
                    raw_output = attributes["output.value"]
                    try:
                        # Only strings are JSON-decoded; other values are
                        # recorded as they are.
                        tool_info["output"] = (
                            json.loads(raw_output)
                            if isinstance(raw_output, str)
                            else raw_output
                        )
                    except (json.JSONDecodeError, TypeError):
                        tool_info["output"] = raw_output
                else:
                    tool_info["output"] = "No output found"

                calls.append(tool_info)

            # Extract LLM calls to see reasoning
            elif "LiteLLMModel.__call__" in span.get("name", ""):
                # The LLM output may be in different places depending on the
                # implementation
                output_content = None

                if "llm.output_messages.0.message.content" in attributes:
                    output_content = attributes["llm.output_messages.0.message.content"]
                # Or try to parse it from the output.value as JSON
                elif "output.value" in attributes:
                    try:
                        output_value = json.loads(attributes["output.value"])
                        if "content" in output_value:
                            output_content = output_value["content"]
                    except (json.JSONDecodeError, TypeError):
                        # Unparseable output: the call is simply not recorded.
                        pass

                if output_content:
                    calls.append(
                        {
                            "model": attributes.get("llm.model_name", "Unknown model"),
                            "output": output_content,
                            "type": "reasoning",
                        }
                    )

        return calls
src/surf_spot_finder/evaluation/telemetry/telemetry.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, ClassVar
2
+ import json
3
+ import re
4
+ from abc import ABC, abstractmethod
5
+ from loguru import logger
6
+ from surf_spot_finder.agents import AgentType
7
+
8
+
9
class TelemetryProcessor(ABC):
    """Base class for processing telemetry data from different agent types.

    Subclasses implement answer extraction and call extraction for one agent
    framework; this class provides the factory, agent-type detection and the
    shared evidence formatting.
    """

    # Maximum number of characters kept from any string value in the evidence.
    MAX_EVIDENCE_LENGTH: ClassVar[int] = 400

    # Matches key=value pairs where the value is single-quoted, double-quoted
    # or bare, and is followed by the next "key=" or the end of the string.
    _KEY_VALUE_PATTERN = re.compile(
        r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
    )

    @classmethod
    def create(cls, agent_type: AgentType) -> "TelemetryProcessor":
        """Factory method to create the appropriate telemetry processor.

        Args:
            agent_type: The agent framework that produced the telemetry.

        Returns:
            A new processor instance for that framework.

        Raises:
            ValueError: If ``agent_type`` has no processor.
        """
        # Processor modules are imported lazily, inside the branch that needs
        # them, rather than at module import time.
        if agent_type == AgentType.LANGCHAIN:
            from surf_spot_finder.evaluation.telemetry.langchain_telemetry import (
                LangchainTelemetryProcessor,
            )

            return LangchainTelemetryProcessor()
        if agent_type == AgentType.SMOLAGENTS:
            from surf_spot_finder.evaluation.telemetry.smolagents_telemetry import (
                SmolagentsTelemetryProcessor,
            )

            return SmolagentsTelemetryProcessor()
        if agent_type == AgentType.OPENAI:
            from surf_spot_finder.evaluation.telemetry.openai_telemetry import (
                OpenAITelemetryProcessor,
            )

            return OpenAITelemetryProcessor()
        raise ValueError(f"Unsupported agent type {agent_type}")

    @staticmethod
    def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
        """Determine the agent type based on the trace.

        These are not really stable ways to find it, because we're waiting on
        some reliable method for determining the agent type. This is a
        temporary solution.

        Args:
            trace: Telemetry spans in recorded order.

        Returns:
            The agent type indicated by the first span that matches one of
            the heuristics.

        Raises:
            ValueError: If no span matches any heuristic.
        """
        for span in trace:
            attributes = span.get("attributes", {})

            input_value = attributes.get("input.value", "")
            # Guard the substring test: a non-string input value would
            # otherwise raise TypeError and abort detection.
            if isinstance(input_value, str) and "langchain" in input_value:
                logger.info("Agent type is LANGCHAIN")
                return AgentType.LANGCHAIN
            if attributes.get("smolagents.max_steps"):
                logger.info("Agent type is SMOLAGENTS")
                return AgentType.SMOLAGENTS
            # This is extremely fragile but there currently isn't
            # any specific key to indicate the agent type
            if span.get("name") == "response":
                logger.info("Agent type is OPENAI")
                return AgentType.OPENAI
        raise ValueError(
            "Could not determine agent type from trace, or agent type not supported"
        )

    @abstractmethod
    def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
        """Extract the hypothesis agent final answer from the trace."""

    @abstractmethod
    def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
        """Extract the agent-specific data from telemetry."""

    def extract_evidence(self, telemetry: List[Dict[str, Any]]) -> str:
        """Extract relevant telemetry evidence.

        Args:
            telemetry: Telemetry spans in recorded order.

        Returns:
            The calls extracted by the subclass, formatted as one text block.
        """
        calls = self._extract_telemetry_data(telemetry)
        return self._format_evidence(calls)

    def _format_evidence(self, calls: List[Dict]) -> str:
        """Format extracted data into a standardized output format.

        Args:
            calls: Call dicts as returned by ``_extract_telemetry_data``.

        Returns:
            A header naming the agent type followed by one numbered,
            JSON-formatted section per call.
        """
        limit = self.MAX_EVIDENCE_LENGTH
        parts = [f"## {self._get_agent_type().name} Agent Execution\n\n"]

        for idx, call in enumerate(calls, start=1):
            # Truncate any top-level string values that are too long
            truncated = {
                key: (
                    value[:limit] + "..."
                    if isinstance(value, str) and len(value) > limit
                    else value
                )
                for key, value in call.items()
            }

            parts.append(f"### Call {idx}\n")
            # Use ensure_ascii=False to prevent escaping Unicode characters
            parts.append(json.dumps(truncated, indent=2, ensure_ascii=False) + "\n\n")

        return "".join(parts)

    @abstractmethod
    def _get_agent_type(self) -> AgentType:
        """Get the agent type associated with this processor."""

    @staticmethod
    def parse_generic_key_value_string(text: str) -> Dict[str, str]:
        """Parse a string of ``key=value`` items into a dict.

        Only splits on '=' signs; a value may be wrapped in single or double
        quotes, which are removed from the result.

        Args:
            text: The string to parse, e.g. ``content='hello' name=None``.

        Returns:
            A dict mapping each key to its value as a string. When a key
            occurs more than once, the last occurrence wins.
        """
        result = {}

        for key, value in TelemetryProcessor._KEY_VALUE_PATTERN.findall(text):
            # Remove surrounding quotes if present
            if (value.startswith("'") and value.endswith("'")) or (
                value.startswith('"') and value.endswith('"')
            ):
                value = value[1:-1]

            result[key.strip()] = value

        return result
src/surf_spot_finder/evaluation/telemetry_utils.py DELETED
@@ -1,301 +0,0 @@
1
- from typing import Any, Dict, List
2
- import json
3
- from langchain_core.messages import BaseMessage
4
- import re
5
-
6
- from surf_spot_finder.agents import AgentType
7
-
8
-
9
- def extract_hypothesis_answer(
10
- trace: List[Dict[str, Any]], agent_type: AgentType
11
- ) -> str:
12
- """Extract the hypothesis agent final answer from the trace"""
13
- for span in reversed(trace):
14
- if agent_type == AgentType.LANGCHAIN:
15
- if span["attributes"]["openinference.span.kind"] == "AGENT":
16
- content = span["attributes"]["output.value"]
17
- # If it's langchain, the actual content is a serialized langchain message that we need to extract.
18
- message = json.loads(content)["messages"][0]
19
- message = parse_generic_key_value_string(message)
20
- base_message = BaseMessage(**message, type="AGENT")
21
- print(base_message.text())
22
- return base_message.text()
23
- elif agent_type == AgentType.SMOLAGENTS:
24
- if span["attributes"]["openinference.span.kind"] == "AGENT":
25
- content = span["attributes"]["output.value"]
26
- # If it's langchain, the actual content is a serialized langchain message that we need to extract.
27
- return content
28
- elif agent_type == AgentType.OPENAI:
29
- # Looking for the final response that has the summary answer
30
- if (
31
- "attributes" in span
32
- and span.get("attributes", {}).get("openinference.span.kind") == "LLM"
33
- ):
34
- output_key = (
35
- "llm.output_messages.0.message.contents.0.message_content.text"
36
- )
37
- if output_key in span["attributes"]:
38
- return span["attributes"][output_key]
39
- else:
40
- raise ValueError(f"Unsupported agent type {agent_type}")
41
- raise ValueError("No agent final answer found in trace")
42
-
43
-
44
- def parse_generic_key_value_string(text):
45
- """
46
- Parse a string that has items of a dict with key-value pairs separated by '='.
47
- Only splits on '=' signs, handling quoted strings properly.
48
- I think this is to compensate for a bug in openinference? https://github.com/Arize-ai/openinference/issues/1401
49
- """
50
-
51
- # Pattern to match key=value pairs, handling quoted values
52
- # This regex looks for word characters followed by = and then captures everything
53
- # until it finds another word character followed by = or the end of the string
54
- # Claude helped me with this one, regex is hard
55
- pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
56
-
57
- result = {}
58
-
59
- matches = re.findall(pattern, text)
60
- for key, value in matches:
61
- # Clean up the key
62
- key = key.strip()
63
-
64
- # Clean up the value - remove surrounding quotes if present
65
- if (value.startswith("'") and value.endswith("'")) or (
66
- value.startswith('"') and value.endswith('"')
67
- ):
68
- value = value[1:-1]
69
-
70
- # Store in result dictionary
71
- result[key] = value
72
-
73
- return result
74
-
75
-
76
- def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
77
- """Extract relevant telemetry evidence based on the agent type."""
78
- # Data extraction function for each agent type
79
- extractors = {
80
- AgentType.SMOLAGENTS: _extract_smolagents_data,
81
- AgentType.LANGCHAIN: _extract_langchain_data,
82
- AgentType.OPENAI: _extract_openai_data,
83
- }
84
-
85
- if agent_type not in extractors:
86
- raise ValueError(f"Unsupported agent type {agent_type}")
87
-
88
- # Extract raw data from telemetry
89
- calls = extractors[agent_type](telemetry)
90
-
91
- # Format data into a consistent structure
92
- return _format_evidence(calls, agent_type)
93
-
94
-
95
- def _extract_smolagents_data(telemetry: List[Dict[str, Any]]) -> List[Dict]:
96
- """Extract LLM calls and tool calls from SmoL Agents telemetry."""
97
- calls = []
98
-
99
- for span in telemetry:
100
- # Skip spans without attributes
101
- if "attributes" not in span:
102
- continue
103
-
104
- attributes = span["attributes"]
105
-
106
- # Extract tool information
107
- if "tool.name" in attributes or span.get("name", "").startswith("SimpleTool"):
108
- tool_info = {
109
- "tool_name": attributes.get(
110
- "tool.name", span.get("name", "Unknown tool")
111
- ),
112
- "status": "success"
113
- if span.get("status", {}).get("status_code") == "OK"
114
- else "error",
115
- "error": span.get("status", {}).get("description", None),
116
- }
117
-
118
- # Extract input if available
119
- if "input.value" in attributes:
120
- try:
121
- input_value = json.loads(attributes["input.value"])
122
- if "kwargs" in input_value:
123
- # For SmoLAgents, the actual input is often in the kwargs field
124
- tool_info["input"] = input_value["kwargs"]
125
- else:
126
- tool_info["input"] = input_value
127
- except (json.JSONDecodeError, TypeError):
128
- tool_info["input"] = attributes["input.value"]
129
-
130
- # Extract output if available
131
- if "output.value" in attributes:
132
- try:
133
- # Try to parse JSON output
134
- output_value = (
135
- json.loads(attributes["output.value"])
136
- if isinstance(attributes["output.value"], str)
137
- else attributes["output.value"]
138
- )
139
- tool_info["output"] = output_value
140
- except (json.JSONDecodeError, TypeError):
141
- tool_info["output"] = attributes["output.value"]
142
- else:
143
- tool_info["output"] = "No output found"
144
-
145
- calls.append(tool_info)
146
-
147
- # Extract LLM calls to see reasoning
148
- elif "LiteLLMModel.__call__" in span.get("name", ""):
149
- # The LLM output may be in different places depending on the implementation
150
- output_content = None
151
-
152
- # Try to get the output from the llm.output_messages.0.message.content attribute
153
- if "llm.output_messages.0.message.content" in attributes:
154
- output_content = attributes["llm.output_messages.0.message.content"]
155
-
156
- # Or try to parse it from the output.value as JSON
157
- elif "output.value" in attributes:
158
- try:
159
- output_value = json.loads(attributes["output.value"])
160
- if "content" in output_value:
161
- output_content = output_value["content"]
162
- except (json.JSONDecodeError, TypeError):
163
- pass
164
-
165
- if output_content:
166
- calls.append(
167
- {
168
- "model": attributes.get("llm.model_name", "Unknown model"),
169
- "output": output_content,
170
- "type": "reasoning",
171
- }
172
- )
173
-
174
- return calls
175
-
176
-
177
- def _extract_langchain_data(telemetry: List[Dict[str, Any]]) -> List:
178
- """Extract LLM calls and tool calls from LangChain telemetry."""
179
- calls = []
180
-
181
- for span in telemetry:
182
- if "attributes" not in span:
183
- continue
184
-
185
- attributes = span.get("attributes", {})
186
- span_kind = attributes.get("openinference.span.kind", "")
187
-
188
- # Collect LLM calls
189
- if span_kind == "LLM" and "llm.output_messages.0.message.content" in attributes:
190
- llm_info = {
191
- "model": attributes.get("llm.model_name", "Unknown model"),
192
- "input": attributes.get("llm.input_messages.0.message.content", ""),
193
- "output": attributes.get("llm.output_messages.0.message.content", ""),
194
- "type": "reasoning",
195
- }
196
- calls.append(llm_info)
197
-
198
- # Try to find tool calls
199
- if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
200
- tool_info = {
201
- "tool_name": attributes.get(
202
- "tool.name", span.get("name", "Unknown tool")
203
- ),
204
- "status": "success"
205
- if span.get("status", {}).get("status_code") == "OK"
206
- else "error",
207
- "error": span.get("status", {}).get("description", None),
208
- }
209
-
210
- if "input.value" in attributes:
211
- try:
212
- input_value = json.loads(attributes["input.value"])
213
- tool_info["input"] = input_value
214
- except Exception:
215
- tool_info["input"] = attributes["input.value"]
216
-
217
- if "output.value" in attributes:
218
- tool_info["output"] = parse_generic_key_value_string(
219
- json.loads(attributes["output.value"])["output"]
220
- )["content"]
221
-
222
- calls.append(tool_info)
223
-
224
- return calls
225
-
226
-
227
- def _extract_openai_data(telemetry: List[Dict[str, Any]]) -> list:
228
- """Extract LLM calls and tool calls from OpenAI telemetry."""
229
- calls = []
230
-
231
- for span in telemetry:
232
- if "attributes" not in span:
233
- continue
234
-
235
- attributes = span.get("attributes", {})
236
- span_kind = attributes.get("openinference.span.kind", "")
237
-
238
- # Collect LLM interactions - look for direct message content first
239
- if span_kind == "LLM":
240
- # Initialize the LLM info dictionary
241
- span_info = {}
242
-
243
- # Try to get input message
244
- input_key = "llm.input_messages.1.message.content" # User message is usually at index 1
245
- if input_key in attributes:
246
- span_info["input"] = attributes[input_key]
247
-
248
- # Try to get output message directly
249
- output_content = None
250
- # Try in multiple possible locations
251
- for key in [
252
- "llm.output_messages.0.message.content",
253
- "llm.output_messages.0.message.contents.0.message_content.text",
254
- ]:
255
- if key in attributes:
256
- output_content = attributes[key]
257
- break
258
-
259
- # If we found direct output content, use it
260
- if output_content:
261
- span_info["output"] = output_content
262
- calls.append(span_info)
263
- elif span_kind == "TOOL":
264
- tool_name = attributes.get("tool.name", "Unknown tool")
265
- tool_output = attributes.get("output.value", "")
266
-
267
- span_info = {
268
- "tool_name": tool_name,
269
- "input": attributes.get("input.value", ""),
270
- "output": tool_output,
271
- "status": span.get("status", {}).get("status_code"),
272
- }
273
- span_info["input"] = json.loads(span_info["input"])
274
-
275
- calls.append(span_info)
276
-
277
- return calls
278
-
279
-
280
- def _format_evidence(calls: List[Dict], agent_type: AgentType) -> str:
281
- """Format extracted data into a standardized output format."""
282
- evidence = f"## {agent_type.name} Agent Execution\n\n"
283
-
284
- for idx, call in enumerate(calls, start=1):
285
- evidence += f"### Call {idx}\n"
286
-
287
- # Truncate any values that are too long
288
- max_length = 400
289
- call = {
290
- k: (
291
- v[:max_length] + "..."
292
- if isinstance(v, str) and len(v) > max_length
293
- else v
294
- )
295
- for k, v in call.items()
296
- }
297
-
298
- # Use ensure_ascii=False to prevent escaping Unicode characters
299
- evidence += json.dumps(call, indent=2, ensure_ascii=False) + "\n\n"
300
-
301
- return evidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/surf_spot_finder/evaluation/test_case.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Optional, Any
2
  from pydantic import BaseModel, Field, ConfigDict
3
  import yaml
4
 
@@ -13,14 +13,6 @@ class InputModel(BaseModel):
13
  json_tracer: bool
14
 
15
 
16
- class AgentModel(BaseModel):
17
- model_id: str
18
- api_key_var: str = "OPENAI_API_KEY"
19
- api_base: Optional[str] = None
20
- agent_type: str
21
- tools: Optional[List[str]] = None
22
-
23
-
24
  class CheckpointCriteria(BaseModel):
25
  """Represents a checkpoint criteria with a description"""
26
 
@@ -32,20 +24,15 @@ class CheckpointCriteria(BaseModel):
32
  class TestCase(BaseModel):
33
  model_config = ConfigDict(extra="forbid")
34
  input: InputModel
35
- agent: AgentModel
36
  ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
37
  checkpoints: List[CheckpointCriteria] = Field(default_factory=list)
38
  final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
39
 
40
  @classmethod
41
- def from_yaml(cls, test_case_path: str, agent_config_path: str) -> "TestCase":
42
  """Load a test case from a YAML file and process it"""
43
  with open(test_case_path, "r") as f:
44
  test_case_dict = yaml.safe_load(f)
45
-
46
- with open(agent_config_path, "r") as f:
47
- agent_config_dict = yaml.safe_load(f)
48
- test_case_dict["agent"] = agent_config_dict["agent"]
49
  final_answer_criteria = []
50
 
51
  def add_gt_final_answer_criteria(ground_truth_list):
 
1
+ from typing import Dict, List, Any
2
  from pydantic import BaseModel, Field, ConfigDict
3
  import yaml
4
 
 
13
  json_tracer: bool
14
 
15
 
 
 
 
 
 
 
 
 
16
  class CheckpointCriteria(BaseModel):
17
  """Represents a checkpoint criteria with a description"""
18
 
 
24
  class TestCase(BaseModel):
25
  model_config = ConfigDict(extra="forbid")
26
  input: InputModel
 
27
  ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
28
  checkpoints: List[CheckpointCriteria] = Field(default_factory=list)
29
  final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
30
 
31
  @classmethod
32
+ def from_yaml(cls, test_case_path: str) -> "TestCase":
33
  """Load a test case from a YAML file and process it"""
34
  with open(test_case_path, "r") as f:
35
  test_case_dict = yaml.safe_load(f)
 
 
 
 
36
  final_answer_criteria = []
37
 
38
  def add_gt_final_answer_criteria(ground_truth_list):
src/surf_spot_finder/evaluation/utils.py CHANGED
@@ -4,36 +4,11 @@ import re
4
 
5
  from litellm import completion
6
  from textwrap import dedent
7
- from loguru import logger
8
 
9
  from pydantic import BaseModel, ConfigDict
10
- from surf_spot_finder.evaluation.telemetry_utils import extract_evidence
11
  from surf_spot_finder.evaluation.test_case import CheckpointCriteria
12
 
13
- from surf_spot_finder.agents import AgentType
14
-
15
-
16
- def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
17
- """Determine the agent type based on the trace.
18
- These are not really stable ways to find it, because we're waiting on some
19
- reliable method for determining the agent type. This is a temporary solution.
20
- """
21
- for span in trace:
22
- if "langchain" in span.get("attributes", {}).get("input.value", ""):
23
- logger.info("Agent type is LANGCHAIN")
24
- return AgentType.LANGCHAIN
25
- if span.get("attributes", {}).get("smolagents.max_steps"):
26
- logger.info("Agent type is SMOLAGENTS")
27
- return AgentType.SMOLAGENTS
28
- # This is extremely fragile but there currently isn't
29
- # any specific key to indicate the agent type
30
- if span.get("name") == "response":
31
- logger.info("Agent type is OPENAI")
32
- return AgentType.OPENAI
33
- raise ValueError(
34
- "Could not determine agent type from trace, or agent type not supported"
35
- )
36
-
37
 
38
  class EvaluationResult(BaseModel):
39
  """Represents the result of evaluating a criterion"""
@@ -126,7 +101,7 @@ def verify_checkpoints(
126
  telemetry: List[Dict[str, Any]],
127
  checkpoints: List[CheckpointCriteria],
128
  model: str,
129
- agent_type: AgentType,
130
  ) -> List[EvaluationResult]:
131
  """Verify each checkpoint against the telemetry data using LLM
132
  These checkpoints do not take the ground truth or hypothesis
@@ -134,8 +109,7 @@ def verify_checkpoints(
134
  the specific criteria mentioned.
135
  """
136
  results = []
137
-
138
- evidence = extract_evidence(telemetry, agent_type)
139
  print(evidence)
140
  for checkpoint in checkpoints:
141
  criteria = checkpoint.criteria
 
4
 
5
  from litellm import completion
6
  from textwrap import dedent
 
7
 
8
  from pydantic import BaseModel, ConfigDict
9
+ from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
10
  from surf_spot_finder.evaluation.test_case import CheckpointCriteria
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  class EvaluationResult(BaseModel):
14
  """Represents the result of evaluating a criterion"""
 
101
  telemetry: List[Dict[str, Any]],
102
  checkpoints: List[CheckpointCriteria],
103
  model: str,
104
+ processor: TelemetryProcessor,
105
  ) -> List[EvaluationResult]:
106
  """Verify each checkpoint against the telemetry data using LLM
107
  These checkpoints do not take the ground truth or hypothesis
 
109
  the specific criteria mentioned.
110
  """
111
  results = []
112
+ evidence = processor.extract_evidence(telemetry)
 
113
  print(evidence)
114
  for checkpoint in checkpoints:
115
  criteria = checkpoint.criteria