Spaces:
Running
Running
Nathan Brake
commited on
Split telemetry processing into cleaner classes, support ollama (#31)
Browse files- .pylintrc +2 -0
- examples/langchain_single_agent_vertical.yaml +1 -0
- pyproject.toml +1 -0
- src/surf_spot_finder/agents/langchain.py +5 -2
- src/surf_spot_finder/evaluation/evaluate.py +20 -27
- src/surf_spot_finder/evaluation/telemetry/__init__.py +3 -0
- src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py +80 -0
- src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py +102 -0
- src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py +103 -0
- src/surf_spot_finder/evaluation/telemetry/telemetry.py +125 -0
- src/surf_spot_finder/evaluation/telemetry_utils.py +0 -301
- src/surf_spot_finder/evaluation/test_case.py +2 -15
- src/surf_spot_finder/evaluation/utils.py +3 -29
.pylintrc
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
[MESSAGES CONTROL]
|
2 |
+
disable=C0415
|
examples/langchain_single_agent_vertical.yaml
CHANGED
@@ -5,6 +5,7 @@ input:
|
|
5 |
# input_prompt_template:
|
6 |
agent:
|
7 |
model_id: o3-mini
|
|
|
8 |
agent_type: langchain
|
9 |
tools:
|
10 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
|
|
5 |
# input_prompt_template:
|
6 |
agent:
|
7 |
model_id: o3-mini
|
8 |
+
# model_id: ollama/llama3.2:3b
|
9 |
agent_type: langchain
|
10 |
tools:
|
11 |
- "surf_spot_finder.tools.driving_hours_to_meters"
|
pyproject.toml
CHANGED
@@ -21,6 +21,7 @@ langchain = [
|
|
21 |
"langchain",
|
22 |
"langgraph",
|
23 |
"langchain-openai>=0.3.9",
|
|
|
24 |
"openinference-instrumentation-langchain"
|
25 |
]
|
26 |
smolagents = [
|
|
|
21 |
"langchain",
|
22 |
"langgraph",
|
23 |
"langchain-openai>=0.3.9",
|
24 |
+
"langchain-ollama>=0.3.0",
|
25 |
"openinference-instrumentation-langchain"
|
26 |
]
|
27 |
smolagents = [
|
src/surf_spot_finder/agents/langchain.py
CHANGED
@@ -51,8 +51,11 @@ def run_lanchain_agent(
|
|
51 |
if not isinstance(imported_tool, BaseTool):
|
52 |
imported_tool = tool(imported_tool)
|
53 |
imported_tools.append((imported_tool))
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
56 |
agent = create_react_agent(
|
57 |
model=model, tools=imported_tools, checkpointer=MemorySaver()
|
58 |
)
|
|
|
51 |
if not isinstance(imported_tool, BaseTool):
|
52 |
imported_tool = tool(imported_tool)
|
53 |
imported_tools.append((imported_tool))
|
54 |
+
if "/" in model_id:
|
55 |
+
model_provider, model_id = model_id.split("/")
|
56 |
+
model = init_chat_model(model_id, model_provider=model_provider)
|
57 |
+
else:
|
58 |
+
model = init_chat_model(model_id)
|
59 |
agent = create_react_agent(
|
60 |
model=model, tools=imported_tools, checkpointer=MemorySaver()
|
61 |
)
|
src/surf_spot_finder/evaluation/evaluate.py
CHANGED
@@ -8,13 +8,11 @@ from surf_spot_finder.cli import find_surf_spot
|
|
8 |
from surf_spot_finder.config import (
|
9 |
Config,
|
10 |
)
|
11 |
-
from surf_spot_finder.
|
12 |
from surf_spot_finder.evaluation.utils import (
|
13 |
-
determine_agent_type,
|
14 |
verify_checkpoints,
|
15 |
verify_hypothesis_answer,
|
16 |
)
|
17 |
-
from surf_spot_finder.evaluation.telemetry_utils import extract_hypothesis_answer
|
18 |
from surf_spot_finder.evaluation.test_case import TestCase
|
19 |
|
20 |
logger.remove()
|
@@ -22,22 +20,15 @@ logger = logger.opt(ansi=True)
|
|
22 |
logger.add(sys.stdout, colorize=True, format="{message}")
|
23 |
|
24 |
|
25 |
-
def run_agent(test_case: TestCase) -> str:
|
26 |
input_data = test_case.input
|
27 |
-
|
28 |
logger.info("Loading config")
|
29 |
-
config = Config(
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
api_key_var=agent_config.api_key_var,
|
35 |
-
prompt=INPUT_PROMPT,
|
36 |
-
json_tracer=input_data.json_tracer,
|
37 |
-
api_base=agent_config.api_base,
|
38 |
-
agent_type=agent_config.agent_type,
|
39 |
-
tools=agent_config.tools,
|
40 |
-
)
|
41 |
return find_surf_spot(
|
42 |
location=config.location,
|
43 |
date=config.date,
|
@@ -57,12 +48,11 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
|
|
57 |
telemetry: List[Dict[str, Any]] = json.loads(f.read())
|
58 |
logger.info(f"Telemetry loaded from {telemetry_path}")
|
59 |
|
60 |
-
agent_type = determine_agent_type(telemetry)
|
61 |
|
62 |
# Extract the final answer from the telemetry
|
63 |
-
|
64 |
-
|
65 |
-
)
|
66 |
logger.info(
|
67 |
f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
|
68 |
)
|
@@ -72,7 +62,7 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
|
|
72 |
telemetry=telemetry,
|
73 |
checkpoints=test_case.checkpoints,
|
74 |
model=llm_judge,
|
75 |
-
|
76 |
)
|
77 |
|
78 |
hypothesis_answer_results = verify_hypothesis_answer(
|
@@ -116,7 +106,9 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
|
|
116 |
|
117 |
|
118 |
def evaluate(
|
119 |
-
test_case_path: str,
|
|
|
|
|
120 |
) -> None:
|
121 |
"""
|
122 |
Evaluate agent performance using either a provided telemetry file or by running the agent.
|
@@ -125,15 +117,16 @@ def evaluate(
|
|
125 |
telemetry_path: Optional path to an existing telemetry file. If not provided,
|
126 |
the agent will be run to generate one.
|
127 |
"""
|
128 |
-
test_case = TestCase.from_yaml(
|
129 |
-
test_case_path=test_case_path, agent_config_path=agent_config_path
|
130 |
-
)
|
131 |
|
132 |
if telemetry_path is None:
|
133 |
logger.info(
|
134 |
"No telemetry path provided. Running agent to generate telemetry..."
|
135 |
)
|
136 |
-
|
|
|
|
|
|
|
137 |
else:
|
138 |
logger.info(f"Using provided telemetry file: {telemetry_path}")
|
139 |
logger.info(
|
|
|
8 |
from surf_spot_finder.config import (
|
9 |
Config,
|
10 |
)
|
11 |
+
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
12 |
from surf_spot_finder.evaluation.utils import (
|
|
|
13 |
verify_checkpoints,
|
14 |
verify_hypothesis_answer,
|
15 |
)
|
|
|
16 |
from surf_spot_finder.evaluation.test_case import TestCase
|
17 |
|
18 |
logger.remove()
|
|
|
20 |
logger.add(sys.stdout, colorize=True, format="{message}")
|
21 |
|
22 |
|
23 |
+
def run_agent(test_case: TestCase, agent_config_path: str) -> str:
|
24 |
input_data = test_case.input
|
25 |
+
|
26 |
logger.info("Loading config")
|
27 |
+
config = Config.from_yaml(agent_config_path)
|
28 |
+
config.location = input_data.location
|
29 |
+
config.date = input_data.date
|
30 |
+
config.max_driving_hours = input_data.max_driving_hours
|
31 |
+
config.json_tracer = input_data.json_tracer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return find_surf_spot(
|
33 |
location=config.location,
|
34 |
date=config.date,
|
|
|
48 |
telemetry: List[Dict[str, Any]] = json.loads(f.read())
|
49 |
logger.info(f"Telemetry loaded from {telemetry_path}")
|
50 |
|
51 |
+
agent_type = TelemetryProcessor.determine_agent_type(telemetry)
|
52 |
|
53 |
# Extract the final answer from the telemetry
|
54 |
+
processor = TelemetryProcessor.create(agent_type)
|
55 |
+
hypothesis_answer = processor.extract_hypothesis_answer(trace=telemetry)
|
|
|
56 |
logger.info(
|
57 |
f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
|
58 |
)
|
|
|
62 |
telemetry=telemetry,
|
63 |
checkpoints=test_case.checkpoints,
|
64 |
model=llm_judge,
|
65 |
+
processor=processor,
|
66 |
)
|
67 |
|
68 |
hypothesis_answer_results = verify_hypothesis_answer(
|
|
|
106 |
|
107 |
|
108 |
def evaluate(
|
109 |
+
test_case_path: str,
|
110 |
+
agent_config_path: str = None,
|
111 |
+
telemetry_path: Optional[str] = None,
|
112 |
) -> None:
|
113 |
"""
|
114 |
Evaluate agent performance using either a provided telemetry file or by running the agent.
|
|
|
117 |
telemetry_path: Optional path to an existing telemetry file. If not provided,
|
118 |
the agent will be run to generate one.
|
119 |
"""
|
120 |
+
test_case = TestCase.from_yaml(test_case_path=test_case_path)
|
|
|
|
|
121 |
|
122 |
if telemetry_path is None:
|
123 |
logger.info(
|
124 |
"No telemetry path provided. Running agent to generate telemetry..."
|
125 |
)
|
126 |
+
assert (
|
127 |
+
agent_config_path is not None
|
128 |
+
), "Agent config path must be provided if running agent"
|
129 |
+
telemetry_path = run_agent(test_case, agent_config_path)
|
130 |
else:
|
131 |
logger.info(f"Using provided telemetry file: {telemetry_path}")
|
132 |
logger.info(
|
src/surf_spot_finder/evaluation/telemetry/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .telemetry import TelemetryProcessor
|
2 |
+
|
3 |
+
__all__ = ["TelemetryProcessor"]
|
src/surf_spot_finder/evaluation/telemetry/langchain_telemetry.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
import json
|
3 |
+
from langchain_core.messages import BaseMessage
|
4 |
+
|
5 |
+
from surf_spot_finder.agents import AgentType
|
6 |
+
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
7 |
+
|
8 |
+
|
9 |
+
class LangchainTelemetryProcessor(TelemetryProcessor):
|
10 |
+
"""Processor for Langchain agent telemetry data."""
|
11 |
+
|
12 |
+
def _get_agent_type(self) -> AgentType:
|
13 |
+
return AgentType.LANGCHAIN
|
14 |
+
|
15 |
+
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
16 |
+
for span in reversed(trace):
|
17 |
+
if span["attributes"]["openinference.span.kind"] == "AGENT":
|
18 |
+
content = span["attributes"]["output.value"]
|
19 |
+
# Extract content from serialized langchain message
|
20 |
+
message = json.loads(content)["messages"][0]
|
21 |
+
message = self.parse_generic_key_value_string(message)
|
22 |
+
base_message = BaseMessage(content=message["content"], type="AGENT")
|
23 |
+
print(base_message.text())
|
24 |
+
return base_message.text()
|
25 |
+
|
26 |
+
raise ValueError("No agent final answer found in trace")
|
27 |
+
|
28 |
+
def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
|
29 |
+
"""Extract LLM calls and tool calls from LangChain telemetry."""
|
30 |
+
calls = []
|
31 |
+
|
32 |
+
for span in telemetry:
|
33 |
+
if "attributes" not in span:
|
34 |
+
continue
|
35 |
+
|
36 |
+
attributes = span.get("attributes", {})
|
37 |
+
span_kind = attributes.get("openinference.span.kind", "")
|
38 |
+
|
39 |
+
# Collect LLM calls
|
40 |
+
if (
|
41 |
+
span_kind == "LLM"
|
42 |
+
and "llm.output_messages.0.message.content" in attributes
|
43 |
+
):
|
44 |
+
llm_info = {
|
45 |
+
"model": attributes.get("llm.model_name", "Unknown model"),
|
46 |
+
"input": attributes.get("llm.input_messages.0.message.content", ""),
|
47 |
+
"output": attributes.get(
|
48 |
+
"llm.output_messages.0.message.content", ""
|
49 |
+
),
|
50 |
+
"type": "reasoning",
|
51 |
+
}
|
52 |
+
calls.append(llm_info)
|
53 |
+
|
54 |
+
# Try to find tool calls
|
55 |
+
if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
|
56 |
+
tool_info = {
|
57 |
+
"tool_name": attributes.get(
|
58 |
+
"tool.name", span.get("name", "Unknown tool")
|
59 |
+
),
|
60 |
+
"status": "success"
|
61 |
+
if span.get("status", {}).get("status_code") == "OK"
|
62 |
+
else "error",
|
63 |
+
"error": span.get("status", {}).get("description", None),
|
64 |
+
}
|
65 |
+
|
66 |
+
if "input.value" in attributes:
|
67 |
+
try:
|
68 |
+
input_value = json.loads(attributes["input.value"])
|
69 |
+
tool_info["input"] = input_value
|
70 |
+
except Exception:
|
71 |
+
tool_info["input"] = attributes["input.value"]
|
72 |
+
|
73 |
+
if "output.value" in attributes:
|
74 |
+
tool_info["output"] = self.parse_generic_key_value_string(
|
75 |
+
json.loads(attributes["output.value"])["output"]
|
76 |
+
)["content"]
|
77 |
+
|
78 |
+
calls.append(tool_info)
|
79 |
+
|
80 |
+
return calls
|
src/surf_spot_finder/evaluation/telemetry/openai_telemetry.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
import json
|
3 |
+
|
4 |
+
from surf_spot_finder.agents import AgentType
|
5 |
+
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
6 |
+
|
7 |
+
|
8 |
+
class OpenAITelemetryProcessor(TelemetryProcessor):
|
9 |
+
"""Processor for OpenAI agent telemetry data."""
|
10 |
+
|
11 |
+
def _get_agent_type(self) -> AgentType:
|
12 |
+
return AgentType.OPENAI
|
13 |
+
|
14 |
+
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
15 |
+
for span in reversed(trace):
|
16 |
+
# Looking for the final response that has the summary answer
|
17 |
+
if (
|
18 |
+
"attributes" in span
|
19 |
+
and span.get("attributes", {}).get("openinference.span.kind") == "LLM"
|
20 |
+
):
|
21 |
+
output_key = (
|
22 |
+
"llm.output_messages.0.message.contents.0.message_content.text"
|
23 |
+
)
|
24 |
+
if output_key in span["attributes"]:
|
25 |
+
return span["attributes"][output_key]
|
26 |
+
|
27 |
+
raise ValueError("No agent final answer found in trace")
|
28 |
+
|
29 |
+
def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> list:
|
30 |
+
"""Extract LLM calls and tool calls from OpenAI telemetry."""
|
31 |
+
calls = []
|
32 |
+
|
33 |
+
for span in telemetry:
|
34 |
+
if "attributes" not in span:
|
35 |
+
continue
|
36 |
+
|
37 |
+
attributes = span.get("attributes", {})
|
38 |
+
span_kind = attributes.get("openinference.span.kind", "")
|
39 |
+
|
40 |
+
# Collect LLM interactions - look for direct message content first
|
41 |
+
if span_kind == "LLM":
|
42 |
+
# Initialize the LLM info dictionary
|
43 |
+
span_info = {}
|
44 |
+
|
45 |
+
# Try to get input message
|
46 |
+
input_key = "llm.input_messages.1.message.content" # User message is usually at index 1
|
47 |
+
if input_key in attributes:
|
48 |
+
span_info["input"] = attributes[input_key]
|
49 |
+
|
50 |
+
# Try to get output message directly
|
51 |
+
output_content = None
|
52 |
+
# Try in multiple possible locations
|
53 |
+
for key in [
|
54 |
+
"llm.output_messages.0.message.content",
|
55 |
+
"llm.output_messages.0.message.contents.0.message_content.text",
|
56 |
+
]:
|
57 |
+
if key in attributes:
|
58 |
+
output_content = attributes[key]
|
59 |
+
break
|
60 |
+
|
61 |
+
# If we found direct output content, use it
|
62 |
+
if output_content:
|
63 |
+
span_info["output"] = output_content
|
64 |
+
calls.append(span_info)
|
65 |
+
elif span_kind == "TOOL":
|
66 |
+
tool_name = attributes.get("tool.name", "Unknown tool")
|
67 |
+
tool_output = attributes.get("output.value", "")
|
68 |
+
|
69 |
+
span_info = {
|
70 |
+
"tool_name": tool_name,
|
71 |
+
"input": attributes.get("input.value", ""),
|
72 |
+
"output": tool_output,
|
73 |
+
"status": span.get("status", {}).get("status_code"),
|
74 |
+
}
|
75 |
+
span_info["input"] = json.loads(span_info["input"])
|
76 |
+
|
77 |
+
calls.append(span_info)
|
78 |
+
|
79 |
+
return calls
|
80 |
+
|
81 |
+
|
82 |
+
# Backward compatibility functions that use the new class structure
|
83 |
+
def extract_hypothesis_answer(
|
84 |
+
trace: List[Dict[str, Any]], agent_type: AgentType
|
85 |
+
) -> str:
|
86 |
+
"""Extract the hypothesis agent final answer from the trace"""
|
87 |
+
processor = TelemetryProcessor.create(agent_type)
|
88 |
+
return processor.extract_hypothesis_answer(trace)
|
89 |
+
|
90 |
+
|
91 |
+
def parse_generic_key_value_string(text: str) -> Dict[str, str]:
|
92 |
+
"""
|
93 |
+
Parse a string that has items of a dict with key-value pairs separated by '='.
|
94 |
+
Only splits on '=' signs, handling quoted strings properly.
|
95 |
+
"""
|
96 |
+
return TelemetryProcessor.parse_generic_key_value_string(text)
|
97 |
+
|
98 |
+
|
99 |
+
def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
|
100 |
+
"""Extract relevant telemetry evidence based on the agent type."""
|
101 |
+
processor = TelemetryProcessor.create(agent_type)
|
102 |
+
return processor.extract_evidence(telemetry)
|
src/surf_spot_finder/evaluation/telemetry/smolagents_telemetry.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
import json
|
3 |
+
|
4 |
+
from surf_spot_finder.agents import AgentType
|
5 |
+
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
6 |
+
|
7 |
+
|
8 |
+
class SmolagentsTelemetryProcessor(TelemetryProcessor):
|
9 |
+
"""Processor for SmoL Agents telemetry data."""
|
10 |
+
|
11 |
+
def _get_agent_type(self) -> AgentType:
|
12 |
+
return AgentType.SMOLAGENTS
|
13 |
+
|
14 |
+
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
15 |
+
for span in reversed(trace):
|
16 |
+
if span["attributes"]["openinference.span.kind"] == "AGENT":
|
17 |
+
content = span["attributes"]["output.value"]
|
18 |
+
return content
|
19 |
+
|
20 |
+
raise ValueError("No agent final answer found in trace")
|
21 |
+
|
22 |
+
def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
|
23 |
+
"""Extract LLM calls and tool calls from SmoL Agents telemetry."""
|
24 |
+
calls = []
|
25 |
+
|
26 |
+
for span in telemetry:
|
27 |
+
# Skip spans without attributes
|
28 |
+
if "attributes" not in span:
|
29 |
+
continue
|
30 |
+
|
31 |
+
attributes = span["attributes"]
|
32 |
+
|
33 |
+
# Extract tool information
|
34 |
+
if "tool.name" in attributes or span.get("name", "").startswith(
|
35 |
+
"SimpleTool"
|
36 |
+
):
|
37 |
+
tool_info = {
|
38 |
+
"tool_name": attributes.get(
|
39 |
+
"tool.name", span.get("name", "Unknown tool")
|
40 |
+
),
|
41 |
+
"status": "success"
|
42 |
+
if span.get("status", {}).get("status_code") == "OK"
|
43 |
+
else "error",
|
44 |
+
"error": span.get("status", {}).get("description", None),
|
45 |
+
}
|
46 |
+
|
47 |
+
# Extract input if available
|
48 |
+
if "input.value" in attributes:
|
49 |
+
try:
|
50 |
+
input_value = json.loads(attributes["input.value"])
|
51 |
+
if "kwargs" in input_value:
|
52 |
+
# For SmoLAgents, the actual input is often in the kwargs field
|
53 |
+
tool_info["input"] = input_value["kwargs"]
|
54 |
+
else:
|
55 |
+
tool_info["input"] = input_value
|
56 |
+
except (json.JSONDecodeError, TypeError):
|
57 |
+
tool_info["input"] = attributes["input.value"]
|
58 |
+
|
59 |
+
# Extract output if available
|
60 |
+
if "output.value" in attributes:
|
61 |
+
try:
|
62 |
+
# Try to parse JSON output
|
63 |
+
output_value = (
|
64 |
+
json.loads(attributes["output.value"])
|
65 |
+
if isinstance(attributes["output.value"], str)
|
66 |
+
else attributes["output.value"]
|
67 |
+
)
|
68 |
+
tool_info["output"] = output_value
|
69 |
+
except (json.JSONDecodeError, TypeError):
|
70 |
+
tool_info["output"] = attributes["output.value"]
|
71 |
+
else:
|
72 |
+
tool_info["output"] = "No output found"
|
73 |
+
|
74 |
+
calls.append(tool_info)
|
75 |
+
|
76 |
+
# Extract LLM calls to see reasoning
|
77 |
+
elif "LiteLLMModel.__call__" in span.get("name", ""):
|
78 |
+
# The LLM output may be in different places depending on the implementation
|
79 |
+
output_content = None
|
80 |
+
|
81 |
+
# Try to get the output from the llm.output_messages.0.message.content attribute
|
82 |
+
if "llm.output_messages.0.message.content" in attributes:
|
83 |
+
output_content = attributes["llm.output_messages.0.message.content"]
|
84 |
+
|
85 |
+
# Or try to parse it from the output.value as JSON
|
86 |
+
elif "output.value" in attributes:
|
87 |
+
try:
|
88 |
+
output_value = json.loads(attributes["output.value"])
|
89 |
+
if "content" in output_value:
|
90 |
+
output_content = output_value["content"]
|
91 |
+
except (json.JSONDecodeError, TypeError):
|
92 |
+
pass
|
93 |
+
|
94 |
+
if output_content:
|
95 |
+
calls.append(
|
96 |
+
{
|
97 |
+
"model": attributes.get("llm.model_name", "Unknown model"),
|
98 |
+
"output": output_content,
|
99 |
+
"type": "reasoning",
|
100 |
+
}
|
101 |
+
)
|
102 |
+
|
103 |
+
return calls
|
src/surf_spot_finder/evaluation/telemetry/telemetry.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, ClassVar
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from loguru import logger
|
6 |
+
from surf_spot_finder.agents import AgentType
|
7 |
+
|
8 |
+
|
9 |
+
class TelemetryProcessor(ABC):
|
10 |
+
"""Base class for processing telemetry data from different agent types."""
|
11 |
+
|
12 |
+
MAX_EVIDENCE_LENGTH: ClassVar[int] = 400
|
13 |
+
|
14 |
+
@classmethod
|
15 |
+
def create(cls, agent_type: AgentType) -> "TelemetryProcessor":
|
16 |
+
"""Factory method to create the appropriate telemetry processor."""
|
17 |
+
if agent_type == AgentType.LANGCHAIN:
|
18 |
+
from surf_spot_finder.evaluation.telemetry.langchain_telemetry import (
|
19 |
+
LangchainTelemetryProcessor,
|
20 |
+
)
|
21 |
+
|
22 |
+
return LangchainTelemetryProcessor()
|
23 |
+
elif agent_type == AgentType.SMOLAGENTS:
|
24 |
+
from surf_spot_finder.evaluation.telemetry.smolagents_telemetry import (
|
25 |
+
SmolagentsTelemetryProcessor,
|
26 |
+
)
|
27 |
+
|
28 |
+
return SmolagentsTelemetryProcessor()
|
29 |
+
elif agent_type == AgentType.OPENAI:
|
30 |
+
from surf_spot_finder.evaluation.telemetry.openai_telemetry import (
|
31 |
+
OpenAITelemetryProcessor,
|
32 |
+
)
|
33 |
+
|
34 |
+
return OpenAITelemetryProcessor()
|
35 |
+
else:
|
36 |
+
raise ValueError(f"Unsupported agent type {agent_type}")
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
|
40 |
+
"""Determine the agent type based on the trace.
|
41 |
+
These are not really stable ways to find it, because we're waiting on some
|
42 |
+
reliable method for determining the agent type. This is a temporary solution.
|
43 |
+
"""
|
44 |
+
for span in trace:
|
45 |
+
if "langchain" in span.get("attributes", {}).get("input.value", ""):
|
46 |
+
logger.info("Agent type is LANGCHAIN")
|
47 |
+
return AgentType.LANGCHAIN
|
48 |
+
if span.get("attributes", {}).get("smolagents.max_steps"):
|
49 |
+
logger.info("Agent type is SMOLAGENTS")
|
50 |
+
return AgentType.SMOLAGENTS
|
51 |
+
# This is extremely fragile but there currently isn't
|
52 |
+
# any specific key to indicate the agent type
|
53 |
+
if span.get("name") == "response":
|
54 |
+
logger.info("Agent type is OPENAI")
|
55 |
+
return AgentType.OPENAI
|
56 |
+
raise ValueError(
|
57 |
+
"Could not determine agent type from trace, or agent type not supported"
|
58 |
+
)
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def extract_hypothesis_answer(self, trace: List[Dict[str, Any]]) -> str:
|
62 |
+
"""Extract the hypothesis agent final answer from the trace."""
|
63 |
+
pass
|
64 |
+
|
65 |
+
@abstractmethod
|
66 |
+
def _extract_telemetry_data(self, telemetry: List[Dict[str, Any]]) -> List[Dict]:
|
67 |
+
"""Extract the agent-specific data from telemetry."""
|
68 |
+
pass
|
69 |
+
|
70 |
+
def extract_evidence(self, telemetry: List[Dict[str, Any]]) -> str:
|
71 |
+
"""Extract relevant telemetry evidence."""
|
72 |
+
calls = self._extract_telemetry_data(telemetry)
|
73 |
+
return self._format_evidence(calls)
|
74 |
+
|
75 |
+
def _format_evidence(self, calls: List[Dict]) -> str:
|
76 |
+
"""Format extracted data into a standardized output format."""
|
77 |
+
evidence = f"## {self._get_agent_type().name} Agent Execution\n\n"
|
78 |
+
|
79 |
+
for idx, call in enumerate(calls, start=1):
|
80 |
+
evidence += f"### Call {idx}\n"
|
81 |
+
|
82 |
+
# Truncate any values that are too long
|
83 |
+
call = {
|
84 |
+
k: (
|
85 |
+
v[: self.MAX_EVIDENCE_LENGTH] + "..."
|
86 |
+
if isinstance(v, str) and len(v) > self.MAX_EVIDENCE_LENGTH
|
87 |
+
else v
|
88 |
+
)
|
89 |
+
for k, v in call.items()
|
90 |
+
}
|
91 |
+
|
92 |
+
# Use ensure_ascii=False to prevent escaping Unicode characters
|
93 |
+
evidence += json.dumps(call, indent=2, ensure_ascii=False) + "\n\n"
|
94 |
+
|
95 |
+
return evidence
|
96 |
+
|
97 |
+
@abstractmethod
|
98 |
+
def _get_agent_type(self) -> AgentType:
|
99 |
+
"""Get the agent type associated with this processor."""
|
100 |
+
pass
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def parse_generic_key_value_string(text: str) -> Dict[str, str]:
|
104 |
+
"""
|
105 |
+
Parse a string that has items of a dict with key-value pairs separated by '='.
|
106 |
+
Only splits on '=' signs, handling quoted strings properly.
|
107 |
+
"""
|
108 |
+
pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
|
109 |
+
result = {}
|
110 |
+
|
111 |
+
matches = re.findall(pattern, text)
|
112 |
+
for key, value in matches:
|
113 |
+
# Clean up the key
|
114 |
+
key = key.strip()
|
115 |
+
|
116 |
+
# Clean up the value - remove surrounding quotes if present
|
117 |
+
if (value.startswith("'") and value.endswith("'")) or (
|
118 |
+
value.startswith('"') and value.endswith('"')
|
119 |
+
):
|
120 |
+
value = value[1:-1]
|
121 |
+
|
122 |
+
# Store in result dictionary
|
123 |
+
result[key] = value
|
124 |
+
|
125 |
+
return result
|
src/surf_spot_finder/evaluation/telemetry_utils.py
DELETED
@@ -1,301 +0,0 @@
|
|
1 |
-
from typing import Any, Dict, List
|
2 |
-
import json
|
3 |
-
from langchain_core.messages import BaseMessage
|
4 |
-
import re
|
5 |
-
|
6 |
-
from surf_spot_finder.agents import AgentType
|
7 |
-
|
8 |
-
|
9 |
-
def extract_hypothesis_answer(
|
10 |
-
trace: List[Dict[str, Any]], agent_type: AgentType
|
11 |
-
) -> str:
|
12 |
-
"""Extract the hypothesis agent final answer from the trace"""
|
13 |
-
for span in reversed(trace):
|
14 |
-
if agent_type == AgentType.LANGCHAIN:
|
15 |
-
if span["attributes"]["openinference.span.kind"] == "AGENT":
|
16 |
-
content = span["attributes"]["output.value"]
|
17 |
-
# If it's langchain, the actual content is a serialized langchain message that we need to extract.
|
18 |
-
message = json.loads(content)["messages"][0]
|
19 |
-
message = parse_generic_key_value_string(message)
|
20 |
-
base_message = BaseMessage(**message, type="AGENT")
|
21 |
-
print(base_message.text())
|
22 |
-
return base_message.text()
|
23 |
-
elif agent_type == AgentType.SMOLAGENTS:
|
24 |
-
if span["attributes"]["openinference.span.kind"] == "AGENT":
|
25 |
-
content = span["attributes"]["output.value"]
|
26 |
-
# If it's langchain, the actual content is a serialized langchain message that we need to extract.
|
27 |
-
return content
|
28 |
-
elif agent_type == AgentType.OPENAI:
|
29 |
-
# Looking for the final response that has the summary answer
|
30 |
-
if (
|
31 |
-
"attributes" in span
|
32 |
-
and span.get("attributes", {}).get("openinference.span.kind") == "LLM"
|
33 |
-
):
|
34 |
-
output_key = (
|
35 |
-
"llm.output_messages.0.message.contents.0.message_content.text"
|
36 |
-
)
|
37 |
-
if output_key in span["attributes"]:
|
38 |
-
return span["attributes"][output_key]
|
39 |
-
else:
|
40 |
-
raise ValueError(f"Unsupported agent type {agent_type}")
|
41 |
-
raise ValueError("No agent final answer found in trace")
|
42 |
-
|
43 |
-
|
44 |
-
def parse_generic_key_value_string(text):
|
45 |
-
"""
|
46 |
-
Parse a string that has items of a dict with key-value pairs separated by '='.
|
47 |
-
Only splits on '=' signs, handling quoted strings properly.
|
48 |
-
I think this is to compensate for a bug in openinference? https://github.com/Arize-ai/openinference/issues/1401
|
49 |
-
"""
|
50 |
-
|
51 |
-
# Pattern to match key=value pairs, handling quoted values
|
52 |
-
# This regex looks for word characters followed by = and then captures everything
|
53 |
-
# until it finds another word character followed by = or the end of the string
|
54 |
-
# Claude helped me with this one, regex is hard
|
55 |
-
pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"
|
56 |
-
|
57 |
-
result = {}
|
58 |
-
|
59 |
-
matches = re.findall(pattern, text)
|
60 |
-
for key, value in matches:
|
61 |
-
# Clean up the key
|
62 |
-
key = key.strip()
|
63 |
-
|
64 |
-
# Clean up the value - remove surrounding quotes if present
|
65 |
-
if (value.startswith("'") and value.endswith("'")) or (
|
66 |
-
value.startswith('"') and value.endswith('"')
|
67 |
-
):
|
68 |
-
value = value[1:-1]
|
69 |
-
|
70 |
-
# Store in result dictionary
|
71 |
-
result[key] = value
|
72 |
-
|
73 |
-
return result
|
74 |
-
|
75 |
-
|
76 |
-
def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
|
77 |
-
"""Extract relevant telemetry evidence based on the agent type."""
|
78 |
-
# Data extraction function for each agent type
|
79 |
-
extractors = {
|
80 |
-
AgentType.SMOLAGENTS: _extract_smolagents_data,
|
81 |
-
AgentType.LANGCHAIN: _extract_langchain_data,
|
82 |
-
AgentType.OPENAI: _extract_openai_data,
|
83 |
-
}
|
84 |
-
|
85 |
-
if agent_type not in extractors:
|
86 |
-
raise ValueError(f"Unsupported agent type {agent_type}")
|
87 |
-
|
88 |
-
# Extract raw data from telemetry
|
89 |
-
calls = extractors[agent_type](telemetry)
|
90 |
-
|
91 |
-
# Format data into a consistent structure
|
92 |
-
return _format_evidence(calls, agent_type)
|
93 |
-
|
94 |
-
|
95 |
-
def _extract_smolagents_data(telemetry: List[Dict[str, Any]]) -> List[Dict]:
|
96 |
-
"""Extract LLM calls and tool calls from SmoL Agents telemetry."""
|
97 |
-
calls = []
|
98 |
-
|
99 |
-
for span in telemetry:
|
100 |
-
# Skip spans without attributes
|
101 |
-
if "attributes" not in span:
|
102 |
-
continue
|
103 |
-
|
104 |
-
attributes = span["attributes"]
|
105 |
-
|
106 |
-
# Extract tool information
|
107 |
-
if "tool.name" in attributes or span.get("name", "").startswith("SimpleTool"):
|
108 |
-
tool_info = {
|
109 |
-
"tool_name": attributes.get(
|
110 |
-
"tool.name", span.get("name", "Unknown tool")
|
111 |
-
),
|
112 |
-
"status": "success"
|
113 |
-
if span.get("status", {}).get("status_code") == "OK"
|
114 |
-
else "error",
|
115 |
-
"error": span.get("status", {}).get("description", None),
|
116 |
-
}
|
117 |
-
|
118 |
-
# Extract input if available
|
119 |
-
if "input.value" in attributes:
|
120 |
-
try:
|
121 |
-
input_value = json.loads(attributes["input.value"])
|
122 |
-
if "kwargs" in input_value:
|
123 |
-
# For SmoLAgents, the actual input is often in the kwargs field
|
124 |
-
tool_info["input"] = input_value["kwargs"]
|
125 |
-
else:
|
126 |
-
tool_info["input"] = input_value
|
127 |
-
except (json.JSONDecodeError, TypeError):
|
128 |
-
tool_info["input"] = attributes["input.value"]
|
129 |
-
|
130 |
-
# Extract output if available
|
131 |
-
if "output.value" in attributes:
|
132 |
-
try:
|
133 |
-
# Try to parse JSON output
|
134 |
-
output_value = (
|
135 |
-
json.loads(attributes["output.value"])
|
136 |
-
if isinstance(attributes["output.value"], str)
|
137 |
-
else attributes["output.value"]
|
138 |
-
)
|
139 |
-
tool_info["output"] = output_value
|
140 |
-
except (json.JSONDecodeError, TypeError):
|
141 |
-
tool_info["output"] = attributes["output.value"]
|
142 |
-
else:
|
143 |
-
tool_info["output"] = "No output found"
|
144 |
-
|
145 |
-
calls.append(tool_info)
|
146 |
-
|
147 |
-
# Extract LLM calls to see reasoning
|
148 |
-
elif "LiteLLMModel.__call__" in span.get("name", ""):
|
149 |
-
# The LLM output may be in different places depending on the implementation
|
150 |
-
output_content = None
|
151 |
-
|
152 |
-
# Try to get the output from the llm.output_messages.0.message.content attribute
|
153 |
-
if "llm.output_messages.0.message.content" in attributes:
|
154 |
-
output_content = attributes["llm.output_messages.0.message.content"]
|
155 |
-
|
156 |
-
# Or try to parse it from the output.value as JSON
|
157 |
-
elif "output.value" in attributes:
|
158 |
-
try:
|
159 |
-
output_value = json.loads(attributes["output.value"])
|
160 |
-
if "content" in output_value:
|
161 |
-
output_content = output_value["content"]
|
162 |
-
except (json.JSONDecodeError, TypeError):
|
163 |
-
pass
|
164 |
-
|
165 |
-
if output_content:
|
166 |
-
calls.append(
|
167 |
-
{
|
168 |
-
"model": attributes.get("llm.model_name", "Unknown model"),
|
169 |
-
"output": output_content,
|
170 |
-
"type": "reasoning",
|
171 |
-
}
|
172 |
-
)
|
173 |
-
|
174 |
-
return calls
|
175 |
-
|
176 |
-
|
177 |
-
def _extract_langchain_data(telemetry: List[Dict[str, Any]]) -> List:
|
178 |
-
"""Extract LLM calls and tool calls from LangChain telemetry."""
|
179 |
-
calls = []
|
180 |
-
|
181 |
-
for span in telemetry:
|
182 |
-
if "attributes" not in span:
|
183 |
-
continue
|
184 |
-
|
185 |
-
attributes = span.get("attributes", {})
|
186 |
-
span_kind = attributes.get("openinference.span.kind", "")
|
187 |
-
|
188 |
-
# Collect LLM calls
|
189 |
-
if span_kind == "LLM" and "llm.output_messages.0.message.content" in attributes:
|
190 |
-
llm_info = {
|
191 |
-
"model": attributes.get("llm.model_name", "Unknown model"),
|
192 |
-
"input": attributes.get("llm.input_messages.0.message.content", ""),
|
193 |
-
"output": attributes.get("llm.output_messages.0.message.content", ""),
|
194 |
-
"type": "reasoning",
|
195 |
-
}
|
196 |
-
calls.append(llm_info)
|
197 |
-
|
198 |
-
# Try to find tool calls
|
199 |
-
if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
|
200 |
-
tool_info = {
|
201 |
-
"tool_name": attributes.get(
|
202 |
-
"tool.name", span.get("name", "Unknown tool")
|
203 |
-
),
|
204 |
-
"status": "success"
|
205 |
-
if span.get("status", {}).get("status_code") == "OK"
|
206 |
-
else "error",
|
207 |
-
"error": span.get("status", {}).get("description", None),
|
208 |
-
}
|
209 |
-
|
210 |
-
if "input.value" in attributes:
|
211 |
-
try:
|
212 |
-
input_value = json.loads(attributes["input.value"])
|
213 |
-
tool_info["input"] = input_value
|
214 |
-
except Exception:
|
215 |
-
tool_info["input"] = attributes["input.value"]
|
216 |
-
|
217 |
-
if "output.value" in attributes:
|
218 |
-
tool_info["output"] = parse_generic_key_value_string(
|
219 |
-
json.loads(attributes["output.value"])["output"]
|
220 |
-
)["content"]
|
221 |
-
|
222 |
-
calls.append(tool_info)
|
223 |
-
|
224 |
-
return calls
|
225 |
-
|
226 |
-
|
227 |
-
def _extract_openai_data(telemetry: List[Dict[str, Any]]) -> list:
|
228 |
-
"""Extract LLM calls and tool calls from OpenAI telemetry."""
|
229 |
-
calls = []
|
230 |
-
|
231 |
-
for span in telemetry:
|
232 |
-
if "attributes" not in span:
|
233 |
-
continue
|
234 |
-
|
235 |
-
attributes = span.get("attributes", {})
|
236 |
-
span_kind = attributes.get("openinference.span.kind", "")
|
237 |
-
|
238 |
-
# Collect LLM interactions - look for direct message content first
|
239 |
-
if span_kind == "LLM":
|
240 |
-
# Initialize the LLM info dictionary
|
241 |
-
span_info = {}
|
242 |
-
|
243 |
-
# Try to get input message
|
244 |
-
input_key = "llm.input_messages.1.message.content" # User message is usually at index 1
|
245 |
-
if input_key in attributes:
|
246 |
-
span_info["input"] = attributes[input_key]
|
247 |
-
|
248 |
-
# Try to get output message directly
|
249 |
-
output_content = None
|
250 |
-
# Try in multiple possible locations
|
251 |
-
for key in [
|
252 |
-
"llm.output_messages.0.message.content",
|
253 |
-
"llm.output_messages.0.message.contents.0.message_content.text",
|
254 |
-
]:
|
255 |
-
if key in attributes:
|
256 |
-
output_content = attributes[key]
|
257 |
-
break
|
258 |
-
|
259 |
-
# If we found direct output content, use it
|
260 |
-
if output_content:
|
261 |
-
span_info["output"] = output_content
|
262 |
-
calls.append(span_info)
|
263 |
-
elif span_kind == "TOOL":
|
264 |
-
tool_name = attributes.get("tool.name", "Unknown tool")
|
265 |
-
tool_output = attributes.get("output.value", "")
|
266 |
-
|
267 |
-
span_info = {
|
268 |
-
"tool_name": tool_name,
|
269 |
-
"input": attributes.get("input.value", ""),
|
270 |
-
"output": tool_output,
|
271 |
-
"status": span.get("status", {}).get("status_code"),
|
272 |
-
}
|
273 |
-
span_info["input"] = json.loads(span_info["input"])
|
274 |
-
|
275 |
-
calls.append(span_info)
|
276 |
-
|
277 |
-
return calls
|
278 |
-
|
279 |
-
|
280 |
-
def _format_evidence(calls: List[Dict], agent_type: AgentType) -> str:
|
281 |
-
"""Format extracted data into a standardized output format."""
|
282 |
-
evidence = f"## {agent_type.name} Agent Execution\n\n"
|
283 |
-
|
284 |
-
for idx, call in enumerate(calls, start=1):
|
285 |
-
evidence += f"### Call {idx}\n"
|
286 |
-
|
287 |
-
# Truncate any values that are too long
|
288 |
-
max_length = 400
|
289 |
-
call = {
|
290 |
-
k: (
|
291 |
-
v[:max_length] + "..."
|
292 |
-
if isinstance(v, str) and len(v) > max_length
|
293 |
-
else v
|
294 |
-
)
|
295 |
-
for k, v in call.items()
|
296 |
-
}
|
297 |
-
|
298 |
-
# Use ensure_ascii=False to prevent escaping Unicode characters
|
299 |
-
evidence += json.dumps(call, indent=2, ensure_ascii=False) + "\n\n"
|
300 |
-
|
301 |
-
return evidence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/surf_spot_finder/evaluation/test_case.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Dict, List,
|
2 |
from pydantic import BaseModel, Field, ConfigDict
|
3 |
import yaml
|
4 |
|
@@ -13,14 +13,6 @@ class InputModel(BaseModel):
|
|
13 |
json_tracer: bool
|
14 |
|
15 |
|
16 |
-
class AgentModel(BaseModel):
|
17 |
-
model_id: str
|
18 |
-
api_key_var: str = "OPENAI_API_KEY"
|
19 |
-
api_base: Optional[str] = None
|
20 |
-
agent_type: str
|
21 |
-
tools: Optional[List[str]] = None
|
22 |
-
|
23 |
-
|
24 |
class CheckpointCriteria(BaseModel):
|
25 |
"""Represents a checkpoint criteria with a description"""
|
26 |
|
@@ -32,20 +24,15 @@ class CheckpointCriteria(BaseModel):
|
|
32 |
class TestCase(BaseModel):
|
33 |
model_config = ConfigDict(extra="forbid")
|
34 |
input: InputModel
|
35 |
-
agent: AgentModel
|
36 |
ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
|
37 |
checkpoints: List[CheckpointCriteria] = Field(default_factory=list)
|
38 |
final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
|
39 |
|
40 |
@classmethod
|
41 |
-
def from_yaml(cls, test_case_path: str
|
42 |
"""Load a test case from a YAML file and process it"""
|
43 |
with open(test_case_path, "r") as f:
|
44 |
test_case_dict = yaml.safe_load(f)
|
45 |
-
|
46 |
-
with open(agent_config_path, "r") as f:
|
47 |
-
agent_config_dict = yaml.safe_load(f)
|
48 |
-
test_case_dict["agent"] = agent_config_dict["agent"]
|
49 |
final_answer_criteria = []
|
50 |
|
51 |
def add_gt_final_answer_criteria(ground_truth_list):
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
from pydantic import BaseModel, Field, ConfigDict
|
3 |
import yaml
|
4 |
|
|
|
13 |
json_tracer: bool
|
14 |
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
class CheckpointCriteria(BaseModel):
|
17 |
"""Represents a checkpoint criteria with a description"""
|
18 |
|
|
|
24 |
class TestCase(BaseModel):
|
25 |
model_config = ConfigDict(extra="forbid")
|
26 |
input: InputModel
|
|
|
27 |
ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
|
28 |
checkpoints: List[CheckpointCriteria] = Field(default_factory=list)
|
29 |
final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
|
30 |
|
31 |
@classmethod
|
32 |
+
def from_yaml(cls, test_case_path: str) -> "TestCase":
|
33 |
"""Load a test case from a YAML file and process it"""
|
34 |
with open(test_case_path, "r") as f:
|
35 |
test_case_dict = yaml.safe_load(f)
|
|
|
|
|
|
|
|
|
36 |
final_answer_criteria = []
|
37 |
|
38 |
def add_gt_final_answer_criteria(ground_truth_list):
|
src/surf_spot_finder/evaluation/utils.py
CHANGED
@@ -4,36 +4,11 @@ import re
|
|
4 |
|
5 |
from litellm import completion
|
6 |
from textwrap import dedent
|
7 |
-
from loguru import logger
|
8 |
|
9 |
from pydantic import BaseModel, ConfigDict
|
10 |
-
from surf_spot_finder.evaluation.
|
11 |
from surf_spot_finder.evaluation.test_case import CheckpointCriteria
|
12 |
|
13 |
-
from surf_spot_finder.agents import AgentType
|
14 |
-
|
15 |
-
|
16 |
-
def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
|
17 |
-
"""Determine the agent type based on the trace.
|
18 |
-
These are not really stable ways to find it, because we're waiting on some
|
19 |
-
reliable method for determining the agent type. This is a temporary solution.
|
20 |
-
"""
|
21 |
-
for span in trace:
|
22 |
-
if "langchain" in span.get("attributes", {}).get("input.value", ""):
|
23 |
-
logger.info("Agent type is LANGCHAIN")
|
24 |
-
return AgentType.LANGCHAIN
|
25 |
-
if span.get("attributes", {}).get("smolagents.max_steps"):
|
26 |
-
logger.info("Agent type is SMOLAGENTS")
|
27 |
-
return AgentType.SMOLAGENTS
|
28 |
-
# This is extremely fragile but there currently isn't
|
29 |
-
# any specific key to indicate the agent type
|
30 |
-
if span.get("name") == "response":
|
31 |
-
logger.info("Agent type is OPENAI")
|
32 |
-
return AgentType.OPENAI
|
33 |
-
raise ValueError(
|
34 |
-
"Could not determine agent type from trace, or agent type not supported"
|
35 |
-
)
|
36 |
-
|
37 |
|
38 |
class EvaluationResult(BaseModel):
|
39 |
"""Represents the result of evaluating a criterion"""
|
@@ -126,7 +101,7 @@ def verify_checkpoints(
|
|
126 |
telemetry: List[Dict[str, Any]],
|
127 |
checkpoints: List[CheckpointCriteria],
|
128 |
model: str,
|
129 |
-
|
130 |
) -> List[EvaluationResult]:
|
131 |
"""Verify each checkpoint against the telemetry data using LLM
|
132 |
These checkpoints do not take the ground truth or hyupothesis
|
@@ -134,8 +109,7 @@ def verify_checkpoints(
|
|
134 |
the specific criteria mentioned.
|
135 |
"""
|
136 |
results = []
|
137 |
-
|
138 |
-
evidence = extract_evidence(telemetry, agent_type)
|
139 |
print(evidence)
|
140 |
for checkpoint in checkpoints:
|
141 |
criteria = checkpoint.criteria
|
|
|
4 |
|
5 |
from litellm import completion
|
6 |
from textwrap import dedent
|
|
|
7 |
|
8 |
from pydantic import BaseModel, ConfigDict
|
9 |
+
from surf_spot_finder.evaluation.telemetry import TelemetryProcessor
|
10 |
from surf_spot_finder.evaluation.test_case import CheckpointCriteria
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
class EvaluationResult(BaseModel):
|
14 |
"""Represents the result of evaluating a criterion"""
|
|
|
101 |
telemetry: List[Dict[str, Any]],
|
102 |
checkpoints: List[CheckpointCriteria],
|
103 |
model: str,
|
104 |
+
processor: TelemetryProcessor,
|
105 |
) -> List[EvaluationResult]:
|
106 |
"""Verify each checkpoint against the telemetry data using LLM
|
107 |
These checkpoints do not take the ground truth or hyupothesis
|
|
|
109 |
the specific criteria mentioned.
|
110 |
"""
|
111 |
results = []
|
112 |
+
evidence = processor.extract_evidence(telemetry)
|
|
|
113 |
print(evidence)
|
114 |
for checkpoint in checkpoints:
|
115 |
criteria = checkpoint.criteria
|