|
Below is a practical, lightweight recipe you can adapt to measure **exact-match accuracy** (the metric GAIA uses) on your new evaluation file. |
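
The snippets below assume each record in your evaluation file is a JSON object with a question field and a gold-answer field. The key names used here (`question`, `Final answer`) are assumptions borrowed from GAIA-style metadata; rename them to match your file. A minimal record might look like:

```python
# One record, as the scripts below expect it (key names are assumptions):
record = {
    "question": "Which city hosted the 2004 Summer Olympics?",  # prompt sent to the agent
    "Final answer": "Athens",                                   # gold answer for exact-match scoring
}
```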
|
|
|
--- |
|
|
|
### 1 Define a thin wrapper around your agent |
|
|
|
```python
# agent_wrapper.py
from typing import Dict


class MyAgent:
    """
    Replace the `answer` method with however you call your own agent
    (API call, local model .predict(), etc.).
    """

    def answer(self, record: Dict) -> str:
        prompt = record["question"]
        # ► ► your code here ◄ ◄
        response = ...  # the raw answer string
        return response.strip()
```
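
If you want to sanity-check the harness before wiring in a real model, a trivial stand-in such as the hypothetical `EchoGoldAgent` below simply returns the gold answer and should score 100 %:

```python
# smoke_test_agent.py (hypothetical helper, only for testing the evaluation loop)
from typing import Dict


class EchoGoldAgent:
    """Returns the gold answer verbatim, so the pipeline should report 100 %.
    Useful only to confirm loading, normalization, and scoring work end to end."""

    def answer(self, record: Dict) -> str:
        # Assumes the gold answer lives under "Final answer"; adjust if your file differs.
        return str(record.get("Final answer", ""))
```

Swap `MyAgent` for this class in the evaluation script to confirm the plumbing, then switch back to your real agent.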
|
|
|
--- |
|
|
|
### 2 Normalization helpers (GAIA style) |
|
|
|
```python
# normalize.py
import re


def normalize(ans: str) -> str:
    """
    GAIA scoring ≈ quasi-exact match after:
      • trim / collapse whitespace
      • lowercase (safe for numbers, too)
    Extend if you need custom rules (e.g. strip trailing $ or %).
    """
    ans = ans.strip().lower()
    ans = re.sub(r"\s+", " ", ans)  # collapse inner whitespace
    return ans
```
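
If your answers mix formats such as `$1,234` vs. `1234`, one possible extension is a looser variant like the sketch below (this is not GAIA's official scorer; the extra rules are assumptions you should tailor to your data):

```python
# normalize_loose.py (optional sketch; stricter or looser rules may suit your data better)
import re


def normalize_loose(ans: str) -> str:
    """Like normalize(), but also drops currency/percent signs, thousands
    separators, and trailing punctuation."""
    ans = ans.strip().lower()
    ans = re.sub(r"\s+", " ", ans)                  # collapse whitespace
    ans = ans.replace("$", "").replace("%", "")     # strip currency / percent signs
    ans = re.sub(r"(?<=\d),(?=\d{3}\b)", "", ans)   # 1,234 -> 1234
    return ans.strip(" .")
```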
|
|
|
--- |
|
|
|
### 3 Evaluation script |
|
|
|
```python
# evaluate_agent.py
import argparse
import json
import pathlib
import time
from typing import Dict, List

from agent_wrapper import MyAgent
from normalize import normalize


def load_records(path: pathlib.Path) -> List[Dict]:
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)  # your new file is a JSON array


def main(path_eval: str, limit: int | None = None):
    eval_path = pathlib.Path(path_eval)
    records = load_records(eval_path)
    if limit:
        records = records[:limit]

    agent = MyAgent()
    n_total = len(records)
    n_correct = 0
    latencies = []

    for rec in records:
        t0 = time.perf_counter()
        pred = agent.answer(rec)
        latencies.append(time.perf_counter() - t0)

        # GAIA metadata stores the gold answer under "Final answer";
        # fall back to a lowercase key in case your file differs.
        gold = rec.get("Final answer") or rec.get("final answer", "")
        if normalize(pred) == normalize(gold):
            n_correct += 1

    acc = n_correct / n_total * 100
    print(f"Accuracy: {n_correct}/{n_total} ({acc:.2f}%)")
    print(f"Median latency: {sorted(latencies)[len(latencies)//2]:.2f}s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("eval_json", help="common_questions.json (or other)")
    parser.add_argument("--limit", type=int, help="debug with first N records")
    args = parser.parse_args()
    main(args.eval_json, args.limit)
```
|
|
|
*Run*: |
|
|
|
```bash
python3 evaluate_agent.py question_set/common_questions.json
```
|
|
|
--- |
|
|
|
### 4 Customizing |
|
|
|
| Need | Where to tweak |
| --- | --- |
| **Agent call** (local model vs. API with keys, tool use, etc.) | `MyAgent.answer()` |
| **More elaborate normalization** (e.g. strip `$` or `%`, round numbers) | `normalize()` |
| **Partial credit / numeric tolerance** | Replace the `==` comparison with your own logic (see the sketch below) |
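
For the last row, one way to tolerate small numeric differences is a comparison helper like the sketch below (the `1e-3` tolerance is an arbitrary choice, not part of GAIA scoring):

```python
# tolerant_match.py (optional sketch; drop-in replacement for the == comparison)
import math

from normalize import normalize


def is_match(pred: str, gold: str, rel_tol: float = 1e-3) -> bool:
    """Exact match after normalization, falling back to a relative numeric
    tolerance when both sides parse as floats."""
    p, g = normalize(pred), normalize(gold)
    if p == g:
        return True
    try:
        return math.isclose(float(p), float(g), rel_tol=rel_tol)
    except ValueError:
        return False
```

In `evaluate_agent.py`, replace `if normalize(pred) == normalize(gold):` with `if is_match(pred, gold):`.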
|
|
|
--- |
|
|
|
### 5 Interpreting results |
|
|
|
* **Exact-match accuracy** (100 % means your agent reproduced every gold answer).
|
* **Median latency** helps you spot runtime outliers (e.g. unusually long tool chains).
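
If the median hides slow cases, a percentile summary makes outliers easier to see. A minimal sketch (assumes Python 3.8+ for `statistics.quantiles` and at least two samples):

```python
import statistics


def latency_report(latencies: list[float]) -> None:
    """Print median and approximate 95th-percentile latency."""
    med = statistics.median(latencies)
    p95 = statistics.quantiles(latencies, n=20)[-1]  # last of 19 cut points ~ p95
    print(f"Median latency: {med:.2f}s   p95: {p95:.2f}s")
```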