Nathan Brake committed
Uniform Trace extraction and inferring agent_type (#27)
* Need to re-tool the eval files but at least now evidence extraction is possible
* Linting
* linting
* fix unit tests
* lint
* patch langchain and openai output format
* lint
* fix pyproj
- pyproject.toml +2 -1
- src/surf_spot_finder/agents/__init__.py +22 -7
- src/surf_spot_finder/agents/langchain.py +7 -4
- src/surf_spot_finder/agents/openai.py +8 -2
- src/surf_spot_finder/cli.py +6 -3
- src/surf_spot_finder/evaluation/evaluate.py +21 -28
- src/surf_spot_finder/evaluation/telemetry_utils.py +301 -0
- src/surf_spot_finder/evaluation/test_case.py +1 -0
- src/surf_spot_finder/evaluation/test_cases/alpha.yaml +48 -13
- src/surf_spot_finder/evaluation/utils.py +29 -69
- src/surf_spot_finder/tracing.py +21 -17
- tests/unit/test_unit_tracing.py +3 -7
pyproject.toml
CHANGED
@@ -20,6 +20,7 @@ dependencies = [
 langchain = [
     "langchain",
     "langgraph",
+    "langchain-openai>=0.3.9",
     "openinference-instrumentation-langchain"
 ]
 smolagents = [
@@ -29,7 +30,7 @@ smolagents = [
 
 openai = [
     "openai-agents",
-    "openinference-instrumentation-openai-agents"
+    "openinference-instrumentation-openai-agents>=0.1.2"
 ]
 
 mcp = [
src/surf_spot_finder/agents/__init__.py
CHANGED
@@ -1,15 +1,30 @@
+from enum import Enum
 from .langchain import run_lanchain_agent
 from .openai import run_openai_agent, run_openai_multi_agent
 from .smolagents import run_smolagent
 
+
+# Define the available agent type enums
+class AgentType(str, Enum):
+    LANGCHAIN = "langchain"
+    OPENAI = "openai"
+    OPENAI_MULTI_AGENT = "openai_multi_agent"
+    SMOLAGENTS = "smolagents"
+
+
 RUNNERS = {
-
-
-
-
+    AgentType.LANGCHAIN: run_lanchain_agent,
+    AgentType.OPENAI: run_openai_agent,
+    AgentType.SMOLAGENTS: run_smolagent,
+    AgentType.OPENAI_MULTI_AGENT: run_openai_multi_agent,
 }
 
 
-def validate_agent_type(value) -> str:
-
-
+def validate_agent_type(value: str) -> str:
+    try:
+        agent_type = AgentType(value)
+        if agent_type not in RUNNERS:
+            raise ValueError(f"agent_type {value} is valid but has no runner")
+        return value
+    except ValueError:
+        raise ValueError(f"agent_type must be one of {[e.value for e in AgentType]}")
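With the string enum in place, callers validate user input once and then dispatch through RUNNERS. A minimal sketch of the intended call pattern (the runner keyword arguments here are illustrative, not the full signatures):

# Sketch: validate once, then dispatch through the RUNNERS mapping.
# The model_id/prompt kwargs are illustrative; see each runner for its real signature.
from surf_spot_finder.agents import RUNNERS, AgentType, validate_agent_type

agent_type = validate_agent_type("langchain")  # raises ValueError for unknown types
runner = RUNNERS[AgentType(agent_type)]
runner(model_id="o1", prompt="Where should I surf near Vigo tomorrow?")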
src/surf_spot_finder/agents/langchain.py
CHANGED
@@ -14,6 +14,8 @@ try:
 except ImportError:
     langchain_available = False
 
+DEFAULT_RECURSION_LIMIT = 50
+
 
 @logger.catch(reraise=True)
 def run_lanchain_agent(
@@ -52,13 +54,14 @@ def run_lanchain_agent(
 
     model = init_chat_model(model_id)
     agent = create_react_agent(
-        model=model,
-        tools=imported_tools,
-        checkpointer=MemorySaver(),
+        model=model, tools=imported_tools, checkpointer=MemorySaver()
     )
     for step in agent.stream(
         {"messages": [HumanMessage(content=prompt)]},
-        {"configurable": {"thread_id": "abc123"}},
+        {
+            "configurable": {"thread_id": "abc123"},
+            "recursion_limit": DEFAULT_RECURSION_LIMIT,
+        },
         stream_mode="values",
     ):
         step["messages"][-1].pretty_print()
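The new constant raises the ceiling on how many graph steps LangGraph will take before giving up, replacing the much smaller library default. A sketch of the failure mode it governs (the exception handling below is illustrative, not part of this commit):

# Sketch: LangGraph raises GraphRecursionError once the recursion limit
# is exhausted; DEFAULT_RECURSION_LIMIT = 50 raises that ceiling.
from langgraph.errors import GraphRecursionError

try:
    run_lanchain_agent(model_id="o1", prompt="Find surf spots near Vigo")
except GraphRecursionError:
    print("Agent exceeded DEFAULT_RECURSION_LIMIT steps without finishing")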
src/surf_spot_finder/agents/openai.py
CHANGED
@@ -24,6 +24,8 @@ try:
 except ImportError:
     agents_available = None
 
+DEFAULT_MAX_TURNS = 20
+
 
 @logger.catch(reraise=True)
 def run_openai_agent(
@@ -34,6 +36,7 @@ def run_openai_agent(
     api_key_var: Optional[str] = None,
     api_base: Optional[str] = None,
     tools: Optional[list[str]] = None,
+    max_turns: Optional[int] = DEFAULT_MAX_TURNS,
 ) -> RunResult:
     """Runs an OpenAI agent with the given prompt and configuration.
 
@@ -94,7 +97,7 @@ def run_openai_agent(
         name=name,
         tools=imported_tools,
     )
-    result = Runner.run_sync(agent, prompt)
+    result = Runner.run_sync(starting_agent=agent, input=prompt, max_turns=max_turns)
    logger.info(result.final_output)
     return result
 
@@ -105,6 +108,7 @@ def run_openai_multi_agent(
     prompt: str,
     name: str = "surf-spot-finder",
     instructions: Optional[str] = MULTI_AGENT_SYSTEM_PROMPT,
+    max_turns: Optional[int] = DEFAULT_MAX_TURNS,
     **kwargs,
 ) -> RunResult:
     """Runs multiple OpenAI agents orchestrated by a main agent.
@@ -176,6 +180,8 @@ def run_openai_multi_agent(
         ],
     )
 
-    result = Runner.run_sync(main_agent, prompt)
+    result = Runner.run_sync(
+        starting_agent=main_agent, input=prompt, max_turns=max_turns
+    )
     logger.info(result.final_output)
     return result
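max_turns plays the same role for the openai-agents SDK: it caps the agent loop instead of letting a stuck agent spin forever. A hedged sketch (the MaxTurnsExceeded import path is an assumption about the SDK, and the positional kwargs are illustrative; neither is shown in this diff):

# Sketch: Runner.run_sync stops after max_turns; the SDK is assumed to
# raise MaxTurnsExceeded (import path not confirmed by this diff).
from agents.exceptions import MaxTurnsExceeded

try:
    run_openai_agent(model_id="o3-mini", prompt="Find surf spots near Vigo", max_turns=5)
except MaxTurnsExceeded:
    print("Agent burned through 5 turns without a final answer")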
src/surf_spot_finder/cli.py
CHANGED
@@ -26,7 +26,7 @@ def find_surf_spot(
     api_base: Optional[str] = None,
     tools: Optional[list[dict]] = None,
     from_config: Optional[str] = None,
-):
+) -> str:
     """Find the best surf spot based on the given criteria.
 
     Args:
@@ -71,8 +71,10 @@ def find_surf_spot(
     )
 
     logger.info("Setting up tracing")
-    tracer_provider,
-        project_name="surf-spot-finder",
+    tracer_provider, tracing_path = get_tracer_provider(
+        project_name="surf-spot-finder",
+        json_tracer=config.json_tracer,
+        agent_type=config.agent_type,
     )
     setup_tracing(tracer_provider, config.agent_type)
 
@@ -88,6 +90,7 @@ def find_surf_spot(
         api_key_var=config.api_key_var,
         tools=config.tools,
     )
+    return tracing_path
 
 
 def main():
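Returning the trace path is what lets the evaluation harness reload the spans it just produced. A minimal sketch of calling the CLI entry point programmatically (argument values are illustrative):

# Sketch: find_surf_spot now returns the path of the JSON trace it wrote.
telemetry_path = find_surf_spot(
    location="Vigo",
    date="2025-03-27 22:00",
    max_driving_hours=3,
    agent_type="langchain",
    model_id="o1",
    json_tracer=True,
)
print(f"Spans written to {telemetry_path}")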
src/surf_spot_finder/evaluation/evaluate.py
CHANGED
@@ -4,17 +4,17 @@ from textwrap import dedent
 from typing import Any, Dict, List, Optional
 from loguru import logger
 from fire import Fire
-from surf_spot_finder.
+from surf_spot_finder.cli import find_surf_spot
 from surf_spot_finder.config import (
-    DEFAULT_PROMPT,
     Config,
 )
-from surf_spot_finder.
+from surf_spot_finder.prompts.shared import INPUT_PROMPT
 from surf_spot_finder.evaluation.utils import (
-
+    determine_agent_type,
     verify_checkpoints,
     verify_hypothesis_answer,
 )
+from surf_spot_finder.evaluation.telemetry_utils import extract_hypothesis_answer
 from surf_spot_finder.evaluation.test_case import TestCase
 
 logger.remove()
@@ -31,31 +31,23 @@ def run_agent(test_case: TestCase) -> str:
         max_driving_hours=input_data.max_driving_hours,
         model_id=input_data.model_id,
         api_key_var=input_data.api_key_var,
-        prompt=
+        prompt=INPUT_PROMPT,
         json_tracer=input_data.json_tracer,
         api_base=input_data.api_base,
         agent_type=input_data.agent_type,
+        tools=input_data.tools,
     )
-
-
-
-
-
-        project_name=project_name, json_tracer=config.json_tracer
-    )
-    setup_tracing(tracer_provider, agent_type=config.agent_type)
-    logger.info("Running agent")
-    run_smolagent(
+    return find_surf_spot(
+        location=config.location,
+        date=config.date,
+        max_driving_hours=config.max_driving_hours,
+        agent_type=config.agent_type,
         model_id=config.model_id,
         api_key_var=config.api_key_var,
+        json_tracer=config.json_tracer,
         api_base=config.api_base,
-
-        LOCATION=config.location,
-        MAX_DRIVING_HOURS=config.max_driving_hours,
-        DATE=config.date,
-        ),
+        tools=config.tools,
     )
-    return telemetry_path
 
 
 def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
@@ -64,8 +56,12 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         telemetry: List[Dict[str, Any]] = json.loads(f.read())
     logger.info(f"Telemetry loaded from {telemetry_path}")
 
+    agent_type = determine_agent_type(telemetry)
+
     # Extract the final answer from the telemetry
-    hypothesis_answer = extract_hypothesis_answer(telemetry)
+    hypothesis_answer = extract_hypothesis_answer(
+        trace=telemetry, agent_type=agent_type
+    )
     logger.info(
         f"""<yellow>Hypothesis Final answer extracted: {hypothesis_answer}</yellow>"""
     )
@@ -75,6 +71,7 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         telemetry=telemetry,
         checkpoints=test_case.checkpoints,
         model=llm_judge,
+        agent_type=agent_type,
     )
 
     hypothesis_answer_results = verify_hypothesis_answer(
@@ -110,12 +107,8 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
         logger.error(message)
     else:
         logger.info("<green>All checkpoints passed!</green>")
-    logger.info(
-
-    )
-    logger.info(
-        f"<red>Failed checkpoints: {len(failed_checks)}/{len(verification_results)}</red>"
-    )
+    logger.info(f"<green>Passed checkpoints: {len(passed_checks)}</green>")
+    logger.info(f"<red>Failed checkpoints: {len(failed_checks)}</red>")
     logger.info("<green>=====================================</green>")
     logger.info(f"<green>Score: {won_points}/{won_points + missed_points}</green>")
     logger.info("<green>=====================================</green>")
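End to end, the reworked evaluation is: run the agent through the same CLI path users take, reload the JSON spans, infer the framework from the trace itself, then extract and judge. A condensed sketch of that flow (llm_judge stands in for the judge model the test case configures):

# Condensed sketch of the evaluation flow after this commit.
import json

telemetry_path = run_agent(test_case)         # drives find_surf_spot under the hood
with open(telemetry_path) as f:
    telemetry = json.load(f)

agent_type = determine_agent_type(telemetry)  # inferred from spans, not configured
answer = extract_hypothesis_answer(trace=telemetry, agent_type=agent_type)
checkpoint_results = verify_checkpoints(
    telemetry=telemetry,
    checkpoints=test_case.checkpoints,
    model=llm_judge,
    agent_type=agent_type,
)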
src/surf_spot_finder/evaluation/telemetry_utils.py
ADDED
@@ -0,0 +1,301 @@
from typing import Any, Dict, List
import json
from langchain_core.messages import BaseMessage
import re

from surf_spot_finder.agents import AgentType


def extract_hypothesis_answer(
    trace: List[Dict[str, Any]], agent_type: AgentType
) -> str:
    """Extract the hypothesis agent final answer from the trace"""
    for span in reversed(trace):
        if agent_type == AgentType.LANGCHAIN:
            if span["attributes"]["openinference.span.kind"] == "AGENT":
                content = span["attributes"]["output.value"]
                # For langchain, the actual content is a serialized langchain message that we need to extract.
                message = json.loads(content)["messages"][0]
                message = parse_generic_key_value_string(message)
                base_message = BaseMessage(**message, type="AGENT")
                print(base_message.text())
                return base_message.text()
        elif agent_type == AgentType.SMOLAGENTS:
            if span["attributes"]["openinference.span.kind"] == "AGENT":
                content = span["attributes"]["output.value"]
                # For smolagents, the output is already plain text.
                return content
        elif agent_type == AgentType.OPENAI:
            # Looking for the final response that has the summary answer
            if (
                "attributes" in span
                and span.get("attributes", {}).get("openinference.span.kind") == "LLM"
            ):
                output_key = (
                    "llm.output_messages.0.message.contents.0.message_content.text"
                )
                if output_key in span["attributes"]:
                    return span["attributes"][output_key]
        else:
            raise ValueError(f"Unsupported agent type {agent_type}")
    raise ValueError("No agent final answer found in trace")


def parse_generic_key_value_string(text):
    """
    Parse a string that has items of a dict with key-value pairs separated by '='.
    Only splits on '=' signs, handling quoted strings properly.
    I think this is to compensate for a bug in openinference? https://github.com/Arize-ai/openinference/issues/1401
    """

    # Pattern to match key=value pairs, handling quoted values
    # This regex looks for word characters followed by = and then captures everything
    # until it finds another word character followed by = or the end of the string
    # Claude helped me with this one, regex is hard
    pattern = r"(\w+)=('.*?'|\".*?\"|[^'\"=]*?)(?=\s+\w+=|\s*$)"

    result = {}

    matches = re.findall(pattern, text)
    for key, value in matches:
        # Clean up the key
        key = key.strip()

        # Clean up the value - remove surrounding quotes if present
        if (value.startswith("'") and value.endswith("'")) or (
            value.startswith('"') and value.endswith('"')
        ):
            value = value[1:-1]

        # Store in result dictionary
        result[key] = value

    return result


def extract_evidence(telemetry: List[Dict[str, Any]], agent_type: AgentType) -> str:
    """Extract relevant telemetry evidence based on the agent type."""
    # Data extraction function for each agent type
    extractors = {
        AgentType.SMOLAGENTS: _extract_smolagents_data,
        AgentType.LANGCHAIN: _extract_langchain_data,
        AgentType.OPENAI: _extract_openai_data,
    }

    if agent_type not in extractors:
        raise ValueError(f"Unsupported agent type {agent_type}")

    # Extract raw data from telemetry
    calls = extractors[agent_type](telemetry)

    # Format data into a consistent structure
    return _format_evidence(calls, agent_type)


def _extract_smolagents_data(telemetry: List[Dict[str, Any]]) -> List[Dict]:
    """Extract LLM calls and tool calls from smolagents telemetry."""
    calls = []

    for span in telemetry:
        # Skip spans without attributes
        if "attributes" not in span:
            continue

        attributes = span["attributes"]

        # Extract tool information
        if "tool.name" in attributes or span.get("name", "").startswith("SimpleTool"):
            tool_info = {
                "tool_name": attributes.get(
                    "tool.name", span.get("name", "Unknown tool")
                ),
                "status": "success"
                if span.get("status", {}).get("status_code") == "OK"
                else "error",
                "error": span.get("status", {}).get("description", None),
            }

            # Extract input if available
            if "input.value" in attributes:
                try:
                    input_value = json.loads(attributes["input.value"])
                    if "kwargs" in input_value:
                        # For smolagents, the actual input is often in the kwargs field
                        tool_info["input"] = input_value["kwargs"]
                    else:
                        tool_info["input"] = input_value
                except (json.JSONDecodeError, TypeError):
                    tool_info["input"] = attributes["input.value"]

            # Extract output if available
            if "output.value" in attributes:
                try:
                    # Try to parse JSON output
                    output_value = (
                        json.loads(attributes["output.value"])
                        if isinstance(attributes["output.value"], str)
                        else attributes["output.value"]
                    )
                    tool_info["output"] = output_value
                except (json.JSONDecodeError, TypeError):
                    tool_info["output"] = attributes["output.value"]
            else:
                tool_info["output"] = "No output found"

            calls.append(tool_info)

        # Extract LLM calls to see reasoning
        elif "LiteLLMModel.__call__" in span.get("name", ""):
            # The LLM output may be in different places depending on the implementation
            output_content = None

            # Try to get the output from the llm.output_messages.0.message.content attribute
            if "llm.output_messages.0.message.content" in attributes:
                output_content = attributes["llm.output_messages.0.message.content"]

            # Or try to parse it from the output.value as JSON
            elif "output.value" in attributes:
                try:
                    output_value = json.loads(attributes["output.value"])
                    if "content" in output_value:
                        output_content = output_value["content"]
                except (json.JSONDecodeError, TypeError):
                    pass

            if output_content:
                calls.append(
                    {
                        "model": attributes.get("llm.model_name", "Unknown model"),
                        "output": output_content,
                        "type": "reasoning",
                    }
                )

    return calls


def _extract_langchain_data(telemetry: List[Dict[str, Any]]) -> List:
    """Extract LLM calls and tool calls from LangChain telemetry."""
    calls = []

    for span in telemetry:
        if "attributes" not in span:
            continue

        attributes = span.get("attributes", {})
        span_kind = attributes.get("openinference.span.kind", "")

        # Collect LLM calls
        if span_kind == "LLM" and "llm.output_messages.0.message.content" in attributes:
            llm_info = {
                "model": attributes.get("llm.model_name", "Unknown model"),
                "input": attributes.get("llm.input_messages.0.message.content", ""),
                "output": attributes.get("llm.output_messages.0.message.content", ""),
                "type": "reasoning",
            }
            calls.append(llm_info)

        # Try to find tool calls
        if "tool.name" in attributes or span.get("name", "").endswith("Tool"):
            tool_info = {
                "tool_name": attributes.get(
                    "tool.name", span.get("name", "Unknown tool")
                ),
                "status": "success"
                if span.get("status", {}).get("status_code") == "OK"
                else "error",
                "error": span.get("status", {}).get("description", None),
            }

            if "input.value" in attributes:
                try:
                    input_value = json.loads(attributes["input.value"])
                    tool_info["input"] = input_value
                except Exception:
                    tool_info["input"] = attributes["input.value"]

            if "output.value" in attributes:
                tool_info["output"] = parse_generic_key_value_string(
                    json.loads(attributes["output.value"])["output"]
                )["content"]

            calls.append(tool_info)

    return calls


def _extract_openai_data(telemetry: List[Dict[str, Any]]) -> list:
    """Extract LLM calls and tool calls from OpenAI telemetry."""
    calls = []

    for span in telemetry:
        if "attributes" not in span:
            continue

        attributes = span.get("attributes", {})
        span_kind = attributes.get("openinference.span.kind", "")

        # Collect LLM interactions - look for direct message content first
        if span_kind == "LLM":
            # Initialize the LLM info dictionary
            span_info = {}

            # Try to get input message
            input_key = "llm.input_messages.1.message.content"  # User message is usually at index 1
            if input_key in attributes:
                span_info["input"] = attributes[input_key]

            # Try to get output message directly
            output_content = None
            # Try in multiple possible locations
            for key in [
                "llm.output_messages.0.message.content",
                "llm.output_messages.0.message.contents.0.message_content.text",
            ]:
                if key in attributes:
                    output_content = attributes[key]
                    break

            # If we found direct output content, use it
            if output_content:
                span_info["output"] = output_content
                calls.append(span_info)
        elif span_kind == "TOOL":
            tool_name = attributes.get("tool.name", "Unknown tool")
            tool_output = attributes.get("output.value", "")

            span_info = {
                "tool_name": tool_name,
                "input": attributes.get("input.value", ""),
                "output": tool_output,
                "status": span.get("status", {}).get("status_code"),
            }
            span_info["input"] = json.loads(span_info["input"])

            calls.append(span_info)

    return calls


def _format_evidence(calls: List[Dict], agent_type: AgentType) -> str:
    """Format extracted data into a standardized output format."""
    evidence = f"## {agent_type.name} Agent Execution\n\n"

    for idx, call in enumerate(calls, start=1):
        evidence += f"### Call {idx}\n"

        # Truncate any values that are too long
        max_length = 400
        call = {
            k: (
                v[:max_length] + "..."
                if isinstance(v, str) and len(v) > max_length
                else v
            )
            for k, v in call.items()
        }

        # Use ensure_ascii=False to prevent escaping Unicode characters
        evidence += json.dumps(call, indent=2, ensure_ascii=False) + "\n\n"

    return evidence
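To see what parse_generic_key_value_string compensates for: per the linked openinference issue, the instrumentor can serialize a LangChain message as repr-style key=value text instead of JSON. A worked example with an illustrative input string:

# Worked example; the raw string mimics the repr-style serialization
# the regex is built for (sample content is illustrative).
raw = "content='The best spot is Playa de Samil' additional_kwargs={} id='run-42'"
parsed = parse_generic_key_value_string(raw)
assert parsed["content"] == "The best spot is Playa de Samil"
assert parsed["id"] == "run-42"
assert parsed["additional_kwargs"] == "{}"  # values come back as strings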
src/surf_spot_finder/evaluation/test_case.py
CHANGED
@@ -15,6 +15,7 @@ class InputModel(BaseModel):
     json_tracer: bool
     api_base: Optional[str] = None
     agent_type: str
+    tools: Optional[List[str]] = None
 
 
 class CheckpointCriteria(BaseModel):
src/surf_spot_finder/evaluation/test_cases/alpha.yaml
CHANGED
@@ -1,30 +1,65 @@
 # Test case for surf spot finder
+
+# You only need this input data if you want to run the test case, if you pass in a path to a telemetry file this
+# is ignored
 input:
   location: "Vigo"
-  date: "2025-03-
+  date: "2025-03-27 22:00"
   max_driving_hours: 3
-  model_id: "openai/o3-mini"
   api_key_var: "OPENAI_API_KEY"
   json_tracer: true
   api_base: null
-
-
+  # model_id: "openai/o1"
+  # agent_type: "smolagents"
+  # tools:
+  #   - "surf_spot_finder.tools.driving_hours_to_meters"
+  #   - "surf_spot_finder.tools.get_area_lat_lon"
+  #   - "surf_spot_finder.tools.get_surfing_spots"
+  #   - "surf_spot_finder.tools.get_wave_forecast"
+  #   - "surf_spot_finder.tools.get_wind_forecast"
+  #   - "surf_spot_finder.tools.search_web"
+  #   - "surf_spot_finder.tools.visit_webpage"
+  #   - "smolagents.PythonInterpreterTool"
+  #   - "smolagents.FinalAnswerTool"
+  agent_type: langchain
+  model_id: o1
+  tools:
+    - "surf_spot_finder.tools.driving_hours_to_meters"
+    - "surf_spot_finder.tools.get_area_lat_lon"
+    - "surf_spot_finder.tools.get_surfing_spots"
+    - "surf_spot_finder.tools.get_wave_forecast"
+    - "surf_spot_finder.tools.get_wind_forecast"
+    - "surf_spot_finder.tools.search_web"
+    - "surf_spot_finder.tools.visit_webpage"
+  # model_id: o3-mini
+  # agent_type: openai
+  # tools:
+  #   - "surf_spot_finder.tools.driving_hours_to_meters"
+  #   - "surf_spot_finder.tools.get_area_lat_lon"
+  #   - "surf_spot_finder.tools.get_surfing_spots"
+  #   - "surf_spot_finder.tools.get_wave_forecast"
+  #   - "surf_spot_finder.tools.get_wind_forecast"
+  #   - "surf_spot_finder.tools.search_web"
+  #   - "surf_spot_finder.tools.show_plan"
+  #   - "surf_spot_finder.tools.visit_webpage"
 ground_truth:
   - name: "Surf location"
     points: 5
-    value: "Playa de
-  - name: "Water temperature"
-    points: 1
-    value: "about 14°C +-5°C"
-  - name: "Wave height"
-    points: 1
-    value: "about 1 meter"
+    value: "Playa de Samil"
 
 # Base checkpoints for agent behavior
 # These evaluators for these checkpoints
 # will not consider the hypothesis answer or final answer in their decision making
 checkpoints:
   - points: 1
-    criteria: "Check if the agent
+    criteria: "Check if the agent did a web search for nearby surf locations."
+  - points: 1
+    criteria: "Check if the agent used the get_surfing_spots tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_wave_forecast tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_wind_forecast tool and it succeeded"
+  - points: 1
+    criteria: "Check if the agent used the get_area_lat_lon tool and it succeeded"
   - points: 1
-    criteria: "Check if the
+    criteria: "Check if the final answer contains any description about the weather at the chosen location"
src/surf_spot_finder/evaluation/utils.py
CHANGED
@@ -4,10 +4,36 @@ import re
 
 from litellm import completion
 from textwrap import dedent
+from loguru import logger
 
 from pydantic import BaseModel, ConfigDict
+from surf_spot_finder.evaluation.telemetry_utils import extract_evidence
 from surf_spot_finder.evaluation.test_case import CheckpointCriteria
 
+from surf_spot_finder.agents import AgentType
+
+
+def determine_agent_type(trace: List[Dict[str, Any]]) -> AgentType:
+    """Determine the agent type based on the trace.
+    These are not really stable ways to find it, because we're waiting on some
+    reliable method for determining the agent type. This is a temporary solution.
+    """
+    for span in trace:
+        if "langchain" in span.get("attributes", {}).get("input.value", ""):
+            logger.info("Agent type is LANGCHAIN")
+            return AgentType.LANGCHAIN
+        if span.get("attributes", {}).get("smolagents.max_steps"):
+            logger.info("Agent type is SMOLAGENTS")
+            return AgentType.SMOLAGENTS
+        # This is extremely fragile but there currently isn't
+        # any specific key to indicate the agent type
+        if span.get("name") == "response":
+            logger.info("Agent type is OPENAI")
+            return AgentType.OPENAI
+    raise ValueError(
+        "Could not determine agent type from trace, or agent type not supported"
+    )
+
 
 class EvaluationResult(BaseModel):
     """Represents the result of evaluating a criterion"""
@@ -19,15 +45,6 @@ class EvaluationResult(BaseModel):
     points: int
 
 
-def extract_hypothesis_answer(telemetry: List[Dict[str, Any]]) -> str | None:
-    """Extract the hypothesis agent final answer from the telemetry data"""
-    for span in reversed(telemetry):
-        if span.get("attributes", {}).get("openinference.span.kind") == "AGENT":
-            hypo = span.get("attributes", {}).get("output.value")
-            return hypo
-    raise ValueError("Final answer not found in telemetry")
-
-
 def evaluate_criterion(
     criteria: str,
     model: str,
@@ -109,6 +126,7 @@ def verify_checkpoints(
     telemetry: List[Dict[str, Any]],
     checkpoints: List[CheckpointCriteria],
     model: str,
+    agent_type: AgentType,
 ) -> List[EvaluationResult]:
     """Verify each checkpoint against the telemetry data using LLM
     These checkpoints do not take the ground truth or hypothesis
@@ -117,9 +135,10 @@ def verify_checkpoints(
     """
     results = []
 
+    evidence = extract_evidence(telemetry, agent_type)
+    print(evidence)
     for checkpoint in checkpoints:
         criteria = checkpoint.criteria
-        evidence = extract_relevant_evidence(telemetry, criteria)
 
         evaluation = evaluate_criterion(
             criteria=criteria,
@@ -156,62 +175,3 @@ def verify_hypothesis_answer(
         results.append(evaluation)
 
     return results
-
-
-def extract_relevant_evidence(telemetry: List[Dict[str, Any]], criteria: str) -> str:
-    """Extract relevant telemetry evidence based on the checkpoint criteria
-    TODO this is not a very robust implementation, since it requires knowledge about which tools have been
-    implemented. We should abstract this so that it can dynamically figure out what tools may have been used
-    and check for them appropriately. I understand that this tool should probably have some better way of abstracting
-    relevant information from the opentelemetry spans."""
-    evidence = ""
-
-    # Look for evidence of tool usage
-    if "DuckDuckGoSearchTool" in criteria:
-        search_spans = [
-            span for span in telemetry if span.get("name") == "DuckDuckGoSearchTool"
-        ]
-        evidence += f"Search tool was used {len(search_spans)} times.\n"
-        for i, span in enumerate(search_spans):  # Limit to first 3 searches
-            if "attributes" in span and "input.value" in span["attributes"]:
-                try:
-                    input_value = json.loads(span["attributes"]["input.value"])
-                    if "kwargs" in input_value and "query" in input_value["kwargs"]:
-                        evidence += (
-                            f"Search query {i + 1}: {input_value['kwargs']['query']}\n"
-                        )
-                except (json.JSONDecodeError, TypeError):
-                    pass
-
-    # Look for evidence of website fetching
-    if "fetched a website" in criteria:
-        fetch_spans = [
-            span
-            for span in telemetry
-            if span.get("attributes", {}).get("tool.name") == "fetch"
-        ]
-        evidence += f"Website fetch tool was used {len(fetch_spans)} times.\n"
-        for i, span in enumerate(fetch_spans):  # Limit to first 3 fetches
-            if "attributes" in span and "input.value" in span["attributes"]:
-                try:
-                    input_value = json.loads(span["attributes"]["input.value"])
-                    if "kwargs" in input_value and "url" in input_value["kwargs"]:
-                        evidence += (
-                            f"Fetched URL {i + 1}: {input_value['kwargs']['url']}\n"
-                        )
-                except (json.JSONDecodeError, TypeError):
-                    pass
-
-    # Add general evidence about all tool calls
-    tool_calls = {}
-    for span in telemetry:
-        if "name" in span and span["name"] not in tool_calls:
-            tool_calls[span["name"]] = 1
-        elif "name" in span:
-            tool_calls[span["name"]] += 1
-
-    evidence += "\nTool calls summary:\n"
-    for tool, count in tool_calls.items():
-        evidence += f"- {tool}: {count} call(s)\n"
-
-    return evidence
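As its docstring concedes, determine_agent_type keys off incidental span features rather than a dedicated marker. A sketch of the minimal span each branch matches (span contents are illustrative):

# Illustrative spans exercising each detection branch.
langchain_span = {"attributes": {"input.value": '{"id": ["langchain", "schema"]}'}}
smolagents_span = {"attributes": {"smolagents.max_steps": 20}}
openai_span = {"name": "response", "attributes": {}}

assert determine_agent_type([langchain_span]) == AgentType.LANGCHAIN
assert determine_agent_type([smolagents_span]) == AgentType.SMOLAGENTS
assert determine_agent_type([openai_span]) == AgentType.OPENAI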
src/surf_spot_finder/tracing.py
CHANGED
@@ -1,11 +1,10 @@
 import os
 import json
 from datetime import datetime
-
-from opentelemetry import trace
 from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter
+
+from surf_spot_finder.agents import AgentType
 
 
 class JsonFileSpanExporter(SpanExporter):
@@ -44,7 +43,10 @@ class JsonFileSpanExporter(SpanExporter):
 
 
 def get_tracer_provider(
-    project_name: str,
+    project_name: str,
+    json_tracer: bool,
+    agent_type: AgentType,
+    output_dir: str = "telemetry_output",
 ) -> tuple[TracerProvider, str | None]:
     """
     Create a tracer_provider based on the selected mode.
@@ -52,6 +54,7 @@ def get_tracer_provider(
     Args:
         project_name: Name of the project for tracing
         json_tracer: Whether to use the custom JSON file exporter (True) or Phoenix (False)
+        agent_type: The type of agent being used.
         output_dir: The directory where the telemetry output will be stored.
             Only used if `json_tracer=True`.
            Defaults to "telemetry_output".
@@ -66,9 +69,7 @@ def get_tracer_provider(
         timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
 
         tracer_provider = TracerProvider()
-
-
-        file_name = f"{output_dir}/{project_name}-{timestamp}.json"
+        file_name = f"{output_dir}/{agent_type}-{project_name}-{timestamp}.json"
         json_file_exporter = JsonFileSpanExporter(file_name=file_name)
         span_processor = SimpleSpanProcessor(json_file_exporter)
         tracer_provider.add_span_processor(span_processor)
@@ -97,14 +98,17 @@ def setup_tracing(tracer_provider: TracerProvider, agent_type: str) -> None:
     validate_agent_type(agent_type)
 
     if "openai" in agent_type:
-        from openinference.instrumentation.openai_agents import
-
-
+        from openinference.instrumentation.openai_agents import (
+            OpenAIAgentsInstrumentor as Instrumentor,
+        )
     elif agent_type == "smolagents":
-        from openinference.instrumentation.smolagents import
-
-
+        from openinference.instrumentation.smolagents import (
+            SmolagentsInstrumentor as Instrumentor,
+        )
     elif agent_type == "langchain":
-        from openinference.instrumentation.langchain import
-
-
+        from openinference.instrumentation.langchain import (
+            LangChainInstrumentor as Instrumentor,
+        )
+    else:
+        raise ValueError(f"Unsupported agent type: {agent_type}")
+    Instrumentor().instrument(tracer_provider=tracer_provider)
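Usage stays a two-step dance, but the provider now takes the agent type as well, both to name the output file and so callers get the trace path back. A minimal sketch:

# Sketch: build the provider, then instrument the matching framework.
tracer_provider, tracing_path = get_tracer_provider(
    project_name="surf-spot-finder",
    json_tracer=True,
    agent_type=AgentType.SMOLAGENTS,
)
setup_tracing(tracer_provider, "smolagents")
# ... run the agent ...
print(f"Spans exported to {tracing_path}")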
tests/unit/test_unit_tracing.py
CHANGED
@@ -2,31 +2,27 @@ from unittest.mock import patch, MagicMock
 
 import pytest
 
+from surf_spot_finder.agents import AgentType
 from surf_spot_finder.tracing import get_tracer_provider, setup_tracing
 
 
 @pytest.mark.parametrize("json_tracer", [True, False])
 def test_get_tracer_provider(tmp_path, json_tracer):
-    mock_trace = MagicMock()
     mock_tracer_provider = MagicMock()
     mock_register = MagicMock()
 
     with (
-        patch("surf_spot_finder.tracing.trace", mock_trace),
         patch("surf_spot_finder.tracing.TracerProvider", mock_tracer_provider),
         patch("phoenix.otel.register", mock_register),
     ):
         get_tracer_provider(
             project_name="test_project",
             json_tracer=json_tracer,
+            agent_type=AgentType.SMOLAGENTS,
             output_dir=tmp_path / "telemetry",
         )
    assert (tmp_path / "telemetry").exists() == json_tracer
-    if json_tracer:
-        mock_trace.set_tracer_provider.assert_called_once_with(
-            mock_tracer_provider.return_value
-        )
-    else:
+    if not json_tracer:
        mock_register.assert_called_once_with(
            project_name="test_project", set_global_tracer_provider=True
        )