Nathan Brake commited on
Commit
ef766f7
·
unverified ·
1 Parent(s): ffb4e87

The test case no longer specifies which agent is involved (#30)

Browse files

* The test case no longer specifies which agent is involved

* format

src/surf_spot_finder/evaluation/evaluate.py CHANGED
@@ -115,7 +115,9 @@ def evaluate_telemetry(test_case: TestCase, telemetry_path: str) -> bool:
115
  logger.info("<green>=====================================</green>")
116
 
117
 
118
- def evaluate(test_case_path: str, telemetry_path: Optional[str] = None) -> None:
 
 
119
  """
120
  Evaluate agent performance using either a provided telemetry file or by running the agent.
121
 
@@ -123,7 +125,9 @@ def evaluate(test_case_path: str, telemetry_path: Optional[str] = None) -> None:
123
  telemetry_path: Optional path to an existing telemetry file. If not provided,
124
  the agent will be run to generate one.
125
  """
126
- test_case = TestCase.from_yaml(test_case_path)
 
 
127
 
128
  if telemetry_path is None:
129
  logger.info(
 
115
  logger.info("<green>=====================================</green>")
116
 
117
 
118
+ def evaluate(
119
+ test_case_path: str, agent_config_path: str, telemetry_path: Optional[str] = None
120
+ ) -> None:
121
  """
122
  Evaluate agent performance using either a provided telemetry file or by running the agent.
123
 
 
125
  telemetry_path: Optional path to an existing telemetry file. If not provided,
126
  the agent will be run to generate one.
127
  """
128
+ test_case = TestCase.from_yaml(
129
+ test_case_path=test_case_path, agent_config_path=agent_config_path
130
+ )
131
 
132
  if telemetry_path is None:
133
  logger.info(
src/surf_spot_finder/evaluation/test_case.py CHANGED
@@ -15,7 +15,7 @@ class InputModel(BaseModel):
15
 
16
  class AgentModel(BaseModel):
17
  model_id: str
18
- api_key_var: str
19
  api_base: Optional[str] = None
20
  agent_type: str
21
  tools: Optional[List[str]] = None
@@ -38,10 +38,14 @@ class TestCase(BaseModel):
38
  final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
39
 
40
  @classmethod
41
- def from_yaml(cls, case_path: str) -> "TestCase":
42
  """Load a test case from a YAML file and process it"""
43
- with open(case_path, "r") as f:
44
  test_case_dict = yaml.safe_load(f)
 
 
 
 
45
  final_answer_criteria = []
46
 
47
  def add_gt_final_answer_criteria(ground_truth_list):
 
15
 
16
  class AgentModel(BaseModel):
17
  model_id: str
18
+ api_key_var: str = "OPENAI_API_KEY"
19
  api_base: Optional[str] = None
20
  agent_type: str
21
  tools: Optional[List[str]] = None
 
38
  final_answer_criteria: List[CheckpointCriteria] = Field(default_factory=list)
39
 
40
  @classmethod
41
+ def from_yaml(cls, test_case_path: str, agent_config_path: str) -> "TestCase":
42
  """Load a test case from a YAML file and process it"""
43
+ with open(test_case_path, "r") as f:
44
  test_case_dict = yaml.safe_load(f)
45
+
46
+ with open(agent_config_path, "r") as f:
47
+ agent_config_dict = yaml.safe_load(f)
48
+ test_case_dict["agent"] = agent_config_dict["agent"]
49
  final_answer_criteria = []
50
 
51
  def add_gt_final_answer_criteria(ground_truth_list):
src/surf_spot_finder/evaluation/test_cases/alpha.yaml CHANGED
@@ -7,21 +7,6 @@ input:
7
  date: "2025-03-27 22:00"
8
  max_driving_hours: 3
9
  json_tracer: true
10
- agent:
11
- api_key_var: "OPENAI_API_KEY"
12
- api_base: null
13
- model_id: "openai/o1"
14
- agent_type: "smolagents"
15
- tools:
16
- - "surf_spot_finder.tools.driving_hours_to_meters"
17
- - "surf_spot_finder.tools.get_area_lat_lon"
18
- - "surf_spot_finder.tools.get_surfing_spots"
19
- - "surf_spot_finder.tools.get_wave_forecast"
20
- - "surf_spot_finder.tools.get_wind_forecast"
21
- - "surf_spot_finder.tools.search_web"
22
- - "surf_spot_finder.tools.visit_webpage"
23
- - "smolagents.PythonInterpreterTool"
24
- - "smolagents.FinalAnswerTool"
25
 
26
 
27
  ground_truth:
 
7
  date: "2025-03-27 22:00"
8
  max_driving_hours: 3
9
  json_tracer: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  ground_truth: