Spaces:
Sleeping
Sleeping
Upload 9 files
Browse files- examples/gaia/README.md +74 -0
- examples/gaia/__init__.py +8 -0
- examples/gaia/gaia_agent_runner.py +200 -0
- examples/gaia/gaia_agent_server.py +141 -0
- examples/gaia/mcp.json +140 -0
- examples/gaia/prompt.py +26 -0
- examples/gaia/run.py +263 -0
- examples/gaia/stream_aworld_ui.py +66 -0
- examples/gaia/utils.py +191 -0
examples/gaia/README.md
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Running Gaia Benchmark
|
2 |
+
|
3 |
+
AWorld is an AI-powered simulation environment for creating and interacting with virtual worlds.
|
4 |
+
|
5 |
+
You can follow the steps below to run the Gaia benchmark.
|
6 |
+
|
7 |
+
## Prerequisites
|
8 |
+
|
9 |
+
- Git
|
10 |
+
- Conda (Miniconda or Anaconda)
|
11 |
+
- Python 3.11
|
12 |
+
- Node.js
|
13 |
+
|
14 |
+
## Installation
|
15 |
+
|
16 |
+
### 1. Clone the repository:
|
17 |
+
```bash
|
18 |
+
git clone https://github.com/inclusionAI/AWorld.git
|
19 |
+
cd AWorld
|
20 |
+
```
|
21 |
+
|
22 |
+
### 2. Create and activate a Conda environment:
|
23 |
+
```bash
|
24 |
+
conda create --name aworld python=3.11
|
25 |
+
conda activate aworld
|
26 |
+
```
|
27 |
+
|
28 |
+
### 3. Setup the environment:
|
29 |
+
```bash
|
30 |
+
python setup.py install
|
31 |
+
|
32 |
+
git clone https://github.com/haris-musa/excel-mcp-server.git
|
33 |
+
cd excel-mcp-server
|
34 |
+
uv pip install -e .
|
35 |
+
```
|
36 |
+
|
37 |
+
### 4. Download GAIA Dataset
|
38 |
+
https://huggingface.co/datasets/gaia-benchmark/GAIA/tree/main
|
39 |
+
|
40 |
+
## Configuration
|
41 |
+
|
42 |
+
### Create a `.env` file in the project root and add your API credentials:
|
43 |
+
```bash
|
44 |
+
GAIA_DATASET_PATH=<Your Dataset Absolute Path>
|
45 |
+
LLM_API_KEY=<Your API Key>
|
46 |
+
LLM_BASE_URL=<Your Service Provider URL>
|
47 |
+
...
|
48 |
+
```
|
49 |
+
|
50 |
+
## Running the Application
|
51 |
+
|
52 |
+
1. Start the local MCP servers
|
53 |
+
```bash
|
54 |
+
uv run excel-mcp-server
|
55 |
+
```
|
56 |
+
The server should now be running with the configuration specified in `mcp.json`.
|
57 |
+
|
58 |
+
2. Run the script
|
59 |
+
```bash
|
60 |
+
python ./examples/gaia/run.py
|
61 |
+
```
|
62 |
+
Now you could check the output log in the console.
|
63 |
+
|
64 |
+
## Troubleshooting
|
65 |
+
|
66 |
+
If you encounter issues:
|
67 |
+
|
68 |
+
- Verify all environment variables are set correctly
|
69 |
+
- Check that all dependencies are installed
|
70 |
+
- Ensure the dataset path is correctly configured
|
71 |
+
- Check the server logs for any error messages
|
72 |
+
|
73 |
+
## Support
|
74 |
+
For additional help, please [open an issue](https://github.com/inclusionAI/AWorld/issues/new) on our GitHub repository.
|
examples/gaia/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
# Package init for the GAIA example: makes sure python-dotenv is available
# and loads environment variables (dataset path, LLM credentials) from `.env`.
from aworld.utils import import_packages

# Install/import the `dotenv` package on demand if it is not already present.
import_packages(["dotenv"])
from dotenv import load_dotenv

# Load variables such as GAIA_DATASET_PATH / LLM_API_KEY from a `.env` file
# at package import time, so every module in this example sees them.
load_dotenv()
|
examples/gaia/gaia_agent_runner.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import traceback
|
6 |
+
from typing import AsyncGenerator
|
7 |
+
import uuid
|
8 |
+
from aworld.config.conf import AgentConfig, TaskConfig
|
9 |
+
from aworld.agents.llm_agent import Agent
|
10 |
+
from aworld.core.task import Task
|
11 |
+
from aworld.runner import Runners
|
12 |
+
from aworld.output.ui.base import AworldUI
|
13 |
+
from aworld.output.ui.markdown_aworld_ui import MarkdownAworldUI
|
14 |
+
from aworld.output.base import Output
|
15 |
+
from .utils import (
|
16 |
+
add_file_path,
|
17 |
+
load_dataset_meta_dict,
|
18 |
+
question_scorer,
|
19 |
+
)
|
20 |
+
from .prompt import system_prompt
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class GaiaAgentRunner:
    """
    Gaia Agent Runner.

    Builds an LLM-backed "super agent" wired to the MCP servers declared in
    ``mcp_config`` and streams markdown-formatted progress lines for a single
    GAIA benchmark question through the async generator :meth:`run`.
    """

    def __init__(
        self,
        llm_provider: str,
        llm_model_name: str,
        llm_base_url: str,
        llm_api_key: str,
        llm_temperature: float = 0.0,
        mcp_config: dict = None,
    ):
        """Initialize the agent config, the super agent and the GAIA dataset.

        Args:
            llm_provider: LLM provider name, e.g. ``"openai"``.
            llm_model_name: Model identifier.
            llm_base_url: Base URL of the LLM service endpoint.
            llm_api_key: API key for the LLM service.
            llm_temperature: Sampling temperature; 0.0 for deterministic runs.
            mcp_config: Parsed ``mcp.json`` content. ``None`` (default) means
                no MCP servers are attached.
        """
        # Fix: avoid a mutable `{}` default argument -- it would be shared
        # across every GaiaAgentRunner instance. Normalize None to an empty dict.
        mcp_config = mcp_config or {}

        self.agent_config = AgentConfig(
            llm_provider=llm_provider,
            llm_model_name=llm_model_name,
            llm_api_key=llm_api_key,
            llm_base_url=llm_base_url,
            llm_temperature=llm_temperature,
        )

        self.super_agent = Agent(
            conf=self.agent_config,
            name="gaia_super_agent",
            system_prompt=system_prompt,
            mcp_config=mcp_config,
            mcp_servers=mcp_config.get("mcpServers", {}).keys(),
        )

        # Dataset root: env var GAIA_DATASET_PATH, else ./GAIA/2023 next to
        # this file.
        self.gaia_dataset_path = os.path.abspath(
            os.getenv(
                "GAIA_DATASET_PATH",
                os.path.join(os.path.dirname(os.path.abspath(__file__)), "GAIA", "2023"),
            )
        )
        self.full_dataset = load_dataset_meta_dict(self.gaia_dataset_path)
        logger.info(
            f"Gaia Agent Runner initialized: super_agent={self.super_agent}, agent_config={self.agent_config}, gaia_dataset_path={self.gaia_dataset_path}, full_dataset={len(self.full_dataset)}"
        )

    async def run(self, prompt: str):
        """Execute one prompt and yield markdown progress/result chunks.

        ``prompt`` may be a JSON string carrying a GAIA ``task_id`` (the
        matching dataset question is then executed and judged), or any
        free-form question that is run against the agent directly.
        """
        yield (f"\n### GAIA Agent Start!")

        mcp_servers = "\n- ".join(self.super_agent.mcp_servers)
        yield (f"\n```gaia_agent_status\n- {mcp_servers}\n```\n")

        question = None
        data_item = None
        task_id = None
        try:
            json_data = json.loads(prompt)
            task_id = json_data["task_id"]

            data_item = self.full_dataset[task_id]
            question = add_file_path(data_item, file_path=self.gaia_dataset_path)[
                "Question"
            ]
            yield (f"\n```gaia_question\n{json.dumps(data_item, indent=2)}\n```\n")
        except Exception:
            # Fix: do not swallow silently -- record why the GAIA lookup
            # failed (non-JSON prompt, unknown task_id, ...) before falling
            # back to treating the prompt as a plain question.
            logger.debug(
                f"Prompt is not a GAIA task reference: {traceback.format_exc()}"
            )

        if not question:
            logger.warning(
                "Could not find GAIA question for prompt, chat using prompt directly!"
            )
            yield (f"\n{prompt}\n")
            question = prompt

        try:
            task = Task(
                id=(task_id + "." + uuid.uuid1().hex) if task_id else uuid.uuid1().hex,
                input=question,
                agent=self.super_agent,
                event_driven=False,
                conf=TaskConfig(max_steps=20),
            )

            last_output: Output = None
            rich_ui = MarkdownAworldUI()
            async for output in Runners.streamed_run_task(task).stream_events():
                logger.info(f"Gaia Agent Ouput: {output}")
                res = await AworldUI.parse_output(output, rich_ui)
                for item in res if isinstance(res, list) else [res]:
                    if isinstance(item, AsyncGenerator):
                        async for sub_item in item:
                            yield sub_item
                    else:
                        yield item
                        # Remember the last non-generator chunk for judging.
                        last_output = item

            logger.info(f"Gaia Agent Last Output: {last_output}")

            if data_item and last_output:
                final_response = self._judge_answer(data_item, last_output)
                yield final_response

        except Exception:
            logger.error(f"Error processing {prompt}, error: {traceback.format_exc()}")

    def _judge_answer(self, data_item: dict, result: Output):
        """Extract the ``<answer>`` tag from *result*, score it against the
        dataset ground truth, and return a markdown verdict block.

        Args:
            data_item: GAIA dataset record ("Final answer", "Level", ...).
            result: Last streamed output chunk from the agent.
        """
        # Fix: `result` may be an Output object rather than a str and
        # `re.search` requires a string; coerce defensively.
        answer = result if isinstance(result, str) else str(result)
        match = re.search(r"<answer>(.*?)</answer>", answer)
        if match:
            answer = match.group(1)
            logger.info(f"Agent answer: {answer}")
            logger.info(f"Correct answer: {data_item['Final answer']}")

            # Fix: score once instead of three times (scoring may normalize
            # numbers/lists and is not free).
            correct = question_scorer(answer, data_item["Final answer"])
            if correct:
                logger.info(f"Question {data_item['task_id']} Correct!")
            else:
                logger.info(f"Question {data_item['task_id']} Incorrect!")

            # Create the new result record
            new_result = {
                "task_id": data_item["task_id"],
                "level": data_item["Level"],
                "question": data_item["Question"],
                "answer": data_item["Final answer"],
                "response": answer,
                "is_correct": correct,
            }
            return f"\n## Final Result: {'✅' if correct else '❌'}\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
        else:
            # No <answer> tag found: emit the raw response without a verdict.
            new_result = answer
            return f"\n## Final Result:\n \n```gaia_result\n{json.dumps(new_result, indent=2)}\n```"
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
    import asyncio
    import argparse
    from datetime import datetime

    logger = logging.getLogger(__name__)

    # All streamed agent output is appended to a timestamped markdown file
    # under ./output next to this module.
    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    output_file = os.path.join(
        output_dir, f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"
    )

    async def main():
        """Run one prompt through GaiaAgentRunner, logging chunks to a file."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--prompt", type=str, default="")
        args = parser.parse_args()

        try:
            prompt = args.prompt

            llm_provider = os.getenv("LLM_PROVIDER")
            llm_model_name = os.getenv("LLM_MODEL_NAME")
            llm_api_key = os.getenv("LLM_API_KEY")
            llm_base_url = os.getenv("LLM_BASE_URL")
            # Fix: os.getenv returns a *string* whenever the variable is set,
            # so the original could pass either str or float downstream.
            # Convert explicitly so AgentConfig always receives a float.
            llm_temperature = float(os.getenv("LLM_TEMPERATURE", "0.0"))

            def send_output(output):
                # Append each streamed chunk to the markdown log file.
                with open(output_file, "a", encoding="utf-8") as f:
                    f.write(f"{output}\n")

            async for i in GaiaAgentRunner(
                llm_provider=llm_provider,
                llm_model_name=llm_model_name,
                llm_base_url=llm_base_url,
                llm_api_key=llm_api_key,
                llm_temperature=llm_temperature,
            ).run(prompt):
                send_output(i)
        except Exception:
            logger.error(
                f"Error processing {args.prompt}, error: {traceback.format_exc()}"
            )

    asyncio.run(main())
|
examples/gaia/gaia_agent_server.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import AsyncGenerator
|
2 |
+
import traceback
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import json
|
6 |
+
import sys
|
7 |
+
import uuid
|
8 |
+
import time
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class GaiaAgentServer:
    """OpenAI-compatible facade for the GAIA agent.

    Exposes the agent through ``/v1/models``- and ``/v1/chat/completions``-
    shaped payloads; chat responses are streamed as completion chunks.
    """

    def _get_model_config(self):
        """Read the GAIA-specific LLM settings from the environment.

        Returns:
            dict with keys ``provider``, ``model``, ``api_key``, ``base_url``
            and ``temperature`` (float).

        Raises:
            Exception: re-raised after logging if reading the config fails.
        """
        try:
            llm_provider = os.getenv("LLM_PROVIDER_GAIA")
            llm_model_name = os.getenv("LLM_MODEL_NAME_GAIA")
            llm_api_key = os.getenv("LLM_API_KEY_GAIA")
            llm_base_url = os.getenv("LLM_BASE_URL_GAIA")
            # Fix: os.getenv returns a string when the variable is set, so the
            # original returned str or float depending on the environment.
            # Normalize to float so downstream consumers get one type.
            llm_temperature = float(os.getenv("LLM_TEMPERATURE_GAIA", "0.0"))
            return {
                "provider": llm_provider,
                "model": llm_model_name,
                "api_key": llm_api_key,
                "base_url": llm_base_url,
                "temperature": llm_temperature,
            }
        except Exception as e:
            logger.warning(
                f">>> Gaia Agent: GAIA_MODEL_CONFIG is not configured, using LLM"
            )
            raise e

    def models(self):
        """List the single GAIA model in OpenAI ``/v1/models`` shape."""
        model = self._get_model_config()

        return [
            {
                "id": f"{model['provider']}/{model['model']}",
                "name": f"gaia_agent@{model['provider']}/{model['model']}",
            }
        ]

    async def chat_completions(self, body: dict) -> AsyncGenerator[str, None]:
        """Stream chat-completion chunks for an OpenAI-style request body.

        Args:
            body: OpenAI chat request; the last message's ``content`` is used
                as the prompt.

        Yields:
            dicts shaped like OpenAI ``chat.completion.chunk`` objects.
        """

        def response_line(line: str, model: str):
            # One OpenAI "chat.completion.chunk"-shaped delta.
            return {
                "object": "chat.completion.chunk",
                "id": str(uuid.uuid4()).replace("-", ""),
                "choices": [
                    {"index": 0, "delta": {"content": line, "role": "assistant"}}
                ],
                "created": int(time.time()),
                "model": model,
            }

        # Fix: `model` must be bound before the try block; otherwise a failure
        # on a malformed `body` (e.g. missing "messages") reached the except
        # path below with `model` unbound, raising UnboundLocalError instead
        # of yielding the error chunk.
        model = body.get("model", "gaia_agent")

        try:
            logger.info(f">>> Gaia Agent: body={body}")

            prompt = body["messages"][-1]["content"]
            model = model.replace("gaia_agent.", "")

            logger.info(f">>> Gaia Agent: prompt={prompt}, model={model}")

            selected_model = self._get_model_config()

            logger.info(f">>> Gaia Agent: Using model configuration: {selected_model}")

            logger.info(f">>> Gaia Agent Python Path: sys.path={sys.path}")

            llm_provider = selected_model.get("provider")
            llm_model_name = selected_model.get("model")
            llm_api_key = selected_model.get("api_key")
            llm_base_url = selected_model.get("base_url")
            llm_temperature = selected_model.get("temperature", 0.0)

            # Imported lazily so the server module can be loaded without the
            # runner's (heavier) dependency chain.
            from examples.gaia.gaia_agent_runner import GaiaAgentRunner

            mcp_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "mcp.json"
            )
            with open(mcp_path, "r") as f:
                mcp_config = json.load(f)

            runner = GaiaAgentRunner(
                llm_provider=llm_provider,
                llm_model_name=llm_model_name,
                llm_base_url=llm_base_url,
                llm_api_key=llm_api_key,
                llm_temperature=llm_temperature,
                mcp_config=mcp_config,
            )

            logger.info(f">>> Gaia Agent: prompt={prompt}, runner={runner}")

            async for i in runner.run(prompt):
                line = response_line(i, model)
                logger.info(f">>> Gaia Agent Line: {line}")
                yield line

        except Exception:
            emsg = traceback.format_exc()
            logger.error(f">>> Gaia Agent Error: exception {emsg}")
            yield response_line(f"Gaia Agent Error: {emsg}", model)

        finally:
            logger.info(f">>> Gaia Agent Done")
|
108 |
+
|
109 |
+
|
110 |
+
import fastapi
from fastapi.responses import StreamingResponse

app = fastapi.FastAPI()

# Fix: the class is defined above in this very module. The original
# `from examples.gaia.gaia_agent_server import GaiaAgentServer` re-imported
# the module under a second name when this file is run as a script
# ("gaia_agent_server:app"), executing all module-level code twice.
agent_server = GaiaAgentServer()


@app.get("/v1/models")
async def models():
    """Return the available GAIA agent model in OpenAI `/v1/models` shape."""
    return agent_server.models()


@app.post("/v1/chat/completions")
async def chat_completions(request: fastapi.Request):
    """Stream agent output as Server-Sent Events (OpenAI chunk format)."""
    form_data = await request.json()
    logger.info(f">>> Gaia Agent Server: form_data={form_data}")

    async def event_generator():
        async for chunk in agent_server.chat_completions(form_data):
            # Format as SSE: each line needs to start with "data: " and end with two newlines
            yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("gaia_agent_server:app", host="0.0.0.0", port=8888)
|
examples/gaia/mcp.json
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"mcpServers": {
|
3 |
+
"e2b-server": {
|
4 |
+
"command": "npx",
|
5 |
+
"args": [
|
6 |
+
"-y",
|
7 |
+
"@e2b/mcp-server"
|
8 |
+
],
|
9 |
+
"env": {
|
10 |
+
"E2B_API_KEY": "${E2B_API_KEY}",
|
11 |
+
"SESSION_REQUEST_CONNECT_TIMEOUT": "120"
|
12 |
+
}
|
13 |
+
},
|
14 |
+
"filesystem": {
|
15 |
+
"command": "npx",
|
16 |
+
"args": [
|
17 |
+
"-y",
|
18 |
+
"@modelcontextprotocol/server-filesystem",
|
19 |
+
"${FILESYSTEM_SERVER_WORKDIR}"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
"terminal-controller": {
|
23 |
+
"command": "python",
|
24 |
+
"args": [
|
25 |
+
"-m",
|
26 |
+
"terminal_controller"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
"calculator": {
|
30 |
+
"command": "python",
|
31 |
+
"args": [
|
32 |
+
"-m",
|
33 |
+
"mcp_server_calculator"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"excel": {
|
37 |
+
"command": "npx",
|
38 |
+
"args": [
|
39 |
+
"--yes",
|
40 |
+
"@negokaz/excel-mcp-server"
|
41 |
+
],
|
42 |
+
"env": {
|
43 |
+
"EXCEL_MCP_PAGING_CELLS_LIMIT": "4000"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"playwright": {
|
47 |
+
"command": "npx",
|
48 |
+
"args": [
|
49 |
+
"-y",
|
50 |
+
"@executeautomation/playwright-mcp-server"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
"google-search": {
|
54 |
+
"command": "npx",
|
55 |
+
"args": [
|
56 |
+
"-y",
|
57 |
+
"@adenot/mcp-google-search"
|
58 |
+
],
|
59 |
+
"env": {
|
60 |
+
"GOOGLE_API_KEY": "${GOOGLE_API_KEY}",
|
61 |
+
"GOOGLE_SEARCH_ENGINE_ID": "${GOOGLE_SEARCH_ENGINE_ID}"
|
62 |
+
}
|
63 |
+
},
|
64 |
+
"ms-playwright": {
|
65 |
+
"command": "npx",
|
66 |
+
"args": [
|
67 |
+
"@playwright/mcp@latest",
|
68 |
+
"--no-sandbox",
|
69 |
+
"--headless"
|
70 |
+
],
|
71 |
+
"env": {
|
72 |
+
"PLAYWRIGHT_TIMEOUT": "120000",
|
73 |
+
"SESSION_REQUEST_CONNECT_TIMEOUT": "120"
|
74 |
+
}
|
75 |
+
},
|
76 |
+
"audio_server": {
|
77 |
+
"command": "python",
|
78 |
+
"args": [
|
79 |
+
"-m",
|
80 |
+
"mcp_servers.audio_server"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
"image_server": {
|
84 |
+
"command": "python",
|
85 |
+
"args": [
|
86 |
+
"-m",
|
87 |
+
"mcp_servers.image_server"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
"youtube_server": {
|
91 |
+
"command": "python",
|
92 |
+
"args": [
|
93 |
+
"-m",
|
94 |
+
"mcp_servers.youtube_server"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
"video_server": {
|
98 |
+
"command": "python",
|
99 |
+
"args": [
|
100 |
+
"-m",
|
101 |
+
"mcp_servers.video_server"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
"search_server": {
|
105 |
+
"command": "python",
|
106 |
+
"args": [
|
107 |
+
"-m",
|
108 |
+
"mcp_servers.search_server"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
"download_server": {
|
112 |
+
"command": "python",
|
113 |
+
"args": [
|
114 |
+
"-m",
|
115 |
+
"mcp_servers.download_server"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
"document_server": {
|
119 |
+
"command": "python",
|
120 |
+
"args": [
|
121 |
+
"-m",
|
122 |
+
"mcp_servers.document_server"
|
123 |
+
]
|
124 |
+
},
|
125 |
+
"browser_server": {
|
126 |
+
"command": "python",
|
127 |
+
"args": [
|
128 |
+
"-m",
|
129 |
+
"mcp_servers.browser_server"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
"reasoning_server": {
|
133 |
+
"command": "python",
|
134 |
+
"args": [
|
135 |
+
"-m",
|
136 |
+
"mcp_servers.reasoning_server"
|
137 |
+
]
|
138 |
+
}
|
139 |
+
}
|
140 |
+
}
|
examples/gaia/prompt.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# System prompt for the GAIA super agent: instructs the model to decompose
# tasks, use one tool per step, and emit the final answer in <answer> tags
# (the format that gaia_agent_runner's judge parses).
system_prompt = """You are an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing, or web browsing, you can handle it all.
Please note that the task may be complex. Do not attempt to solve it all at once. You should break the task down and use different tools step by step to solve it. After using each tool, clearly explain the execution results and suggest the next steps.
Please utilize appropriate tools for the task, analyze the results obtained from these tools, and provide your reasoning. Always use available tools such as browser, calculator, etc. to verify correctness rather than relying on your internal knowledge.
If you believe the problem has been solved, please output the `final answer`. The `final answer` should be given in <answer></answer> format, while your other thought process should be output in <think></think> tags.
Your `final answer` should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

Here are some tips to help you give better instructions:
<tips>
*. Consider search relevant information first using `search_server` tool. Then break down the problem following the instructions from the search.
1. Do not use any tools outside of the provided tools list.
2. Always use only one tool at a time in each step of your execution.
3. Even if the task is complex, there is always a solution. If you can't find the answer using one method, try another approach or use different tools to find the solution.
4. Due to context length limitations, always try to complete browser-based tasks with the minimal number of steps possible.
5. Before providing the `final answer`, carefully reflect on whether the task has been fully solved. If you have not solved the task, please provide your reasoning and suggest the next steps.
6. When providing the `final answer`, answer the user's question directly and precisely. For example, if asked "what animal is x?" and x is a monkey, simply answer "monkey" rather than "x is a monkey".
7. When you need to process excel file, prioritize using the `excel` tool instead of writing custom code with `terminal-controller` tool.
8. If you need to download a file, please use the `download_server` tool to download the file and save it to the specified path.
9. Use the `search_server` to get the relevant website URLs or contents.
10. When there are questions related to YouTube video comprehension, use tools in `youtube_server` and `video_server` to analyze the video content by the given question.
11. Ensure to call `mcp__reasoning_server__complex_problem_reasoning` for solving complex reasoning tasks, such as riddle, game or competition-level STEM(including code) problems.
12. `e2b-server` is a powerful tool only for running **Python** code. Other programming languages are **NOT SUPPORTED**.
</tips>

Now, here is the task. Stay focused and complete it carefully using the appropriate tools!
"""
# Fix applied above: "calcutor" -> "calculator" (typo in the prompt text).
# 13. Using `mcp__ms-playwright__browser_take_screenshot` tool to save the screenshot of URLs to the specified path when you need to understand the gif / jpg of the URLs.
|
examples/gaia/run.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import traceback
|
7 |
+
from typing import Any, Dict, List
|
8 |
+
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
from aworld.config.conf import AgentConfig, TaskConfig
|
12 |
+
from aworld.agents.llm_agent import Agent
|
13 |
+
from aworld.core.task import Task
|
14 |
+
from aworld.runner import Runners
|
15 |
+
from examples.gaia.prompt import system_prompt
|
16 |
+
from examples.gaia.utils import (
|
17 |
+
add_file_path,
|
18 |
+
load_dataset_meta,
|
19 |
+
question_scorer,
|
20 |
+
report_results,
|
21 |
+
)
|
22 |
+
|
23 |
+
# Create log directory if it doesn't exist.
# Fix: os.path.exists(None) raises an opaque TypeError when LOG_FILE_PATH is
# unset; fail fast with a clear message instead (the rest of this script reads
# LOG_FILE_PATH directly, so there is no sensible silent fallback). Note this
# runs at import time, before load_dotenv() in __main__, so the variable must
# be set in the actual process environment.
_log_file_path = os.getenv("LOG_FILE_PATH")
if _log_file_path is None:
    raise RuntimeError("LOG_FILE_PATH environment variable must be set")
# exist_ok avoids a race between the exists() check and makedirs().
os.makedirs(_log_file_path, exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--start",
    type=int,
    default=0,
    help="Start index of the dataset",
)
parser.add_argument(
    "--end",
    type=int,
    default=20,
    help="End index of the dataset",
)
parser.add_argument(
    "--q",
    type=str,
    help="Question Index, e.g., 0-0-0-0-0. Highest priority: override other arguments if provided.",
)
parser.add_argument(
    "--skip",
    action="store_true",
    help="Skip the question if it has been processed before.",
)
parser.add_argument(
    "--split",
    type=str,
    default="validation",
    help="Split of the dataset, e.g., validation, test",
)
parser.add_argument(
    "--blacklist_file_path",
    type=str,
    nargs="?",
    help="Blacklist file path, e.g., blacklist.txt",
)
args = parser.parse_args()
|
63 |
+
|
64 |
+
|
65 |
+
def setup_logging():
    """Attach a per-run INFO file handler to the root logger.

    Relies on the module-level ``args`` (parsed CLI arguments): the log file
    name embeds either the single question id (``--q``) or the processed
    index range (``--start``/``--end``).
    """
    logging_logger = logging.getLogger()
    logging_logger.setLevel(logging.INFO)

    # Name the log after the slice of work this process handles.
    log_file_name = (
        f"/super_agent_{args.q}.log"
        if args.q
        else f"/super_agent_{args.start}_{args.end}.log"
    )
    # NOTE(review): the getenv fallback "run_super_agent.log" is used as a
    # *directory* prefix here (log_file_name starts with "/"), which looks
    # unintended -- confirm LOG_FILE_PATH is always set to a directory.
    file_handler = logging.FileHandler(
        os.getenv(
            "LOG_FILE_PATH",
            "run_super_agent.log",
        )
        + log_file_name,
        mode="a",
        encoding="utf-8",
    )
    file_handler.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    logging_logger.addHandler(file_handler)
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
    load_dotenv()
    setup_logging()

    gaia_dataset_path = os.getenv("GAIA_DATASET_PATH", "./gaia_dataset")
    full_dataset = load_dataset_meta(gaia_dataset_path, split=args.split)
    logging.info(f"Total questions: {len(full_dataset)}")

    agent_config = AgentConfig(
        llm_provider="openai",
        llm_model_name=os.getenv("LLM_MODEL_NAME", "gpt-4o"),
        llm_api_key=os.getenv("LLM_API_KEY", "your_openai_api_key"),
        llm_base_url=os.getenv("LLM_BASE_URL", "your_openai_base_url"),
    )
    super_agent = Agent(
        conf=agent_config,
        name="gaia_super_agent",
        system_prompt=system_prompt,
        mcp_servers=[
            "e2b-server",
            # "filesystem",
            "terminal-controller",
            "excel",
            "calculator",
            "ms-playwright",
            "audio_server",
            "image_server",
            "video_server",
            "search_server",
            "download_server",
            "document_server",
            # "browser_server",
            "youtube_server",
            "reasoning_server",
        ],
    )

    # load results from the checkpoint file
    if os.path.exists(os.getenv("LOG_FILE_PATH") + "/results.json"):
        with open(
            os.getenv("LOG_FILE_PATH") + "/results.json", "r", encoding="utf-8"
        ) as results_f:
            results: List[Dict[str, Any]] = json.load(results_f)
    else:
        results: List[Dict[str, Any]] = []

    # load blacklist `task_id`
    if args.blacklist_file_path and os.path.exists(args.blacklist_file_path):
        with open(args.blacklist_file_path, "r", encoding="utf-8") as f:
            blacklist = set(f.read().splitlines())
    else:
        blacklist = set()  # Empty set if file doesn't exist

    try:
        # slice dataset by args.start and args.end, overrided by args.q (single `task_id`)
        dataset_slice = (
            [
                dataset_record
                for dataset_record in full_dataset
                if dataset_record["task_id"] in args.q
            ]
            if args.q is not None
            else full_dataset[args.start : args.end]
        )

        # main loop to execute questions
        for i, dataset_i in enumerate(dataset_slice):
            # specify `task_id`
            if args.q and args.q != dataset_i["task_id"]:
                continue
            # only valid for args.q==None
            if not args.q:
                # blacklist
                if dataset_i["task_id"] in blacklist:
                    continue

                # pass: already answered correctly, or Level-3 and already attempted
                if any(
                    # Question Done and Correct
                    (result["task_id"] == dataset_i["task_id"] and result["is_correct"])
                    for result in results
                ) or any(
                    # Question Done and Incorrect, but Level is 3
                    (
                        result["task_id"] == dataset_i["task_id"]
                        and not result["is_correct"]
                        and dataset_i["Level"] == 3
                    )
                    for result in results
                ):
                    continue

                # skip: any previous attempt (correct or not) when --skip is set
                if args.skip and any(
                    (result["task_id"] == dataset_i["task_id"])
                    for result in results
                ):
                    continue

            # run
            try:
                logging.info(f"Start to process: {dataset_i['task_id']}")
                logging.info(f"Detail: {dataset_i}")
                logging.info(f"Question: {dataset_i['Question']}")
                logging.info(f"Level: {dataset_i['Level']}")
                logging.info(f"Tools: {dataset_i['Annotator Metadata']['Tools']}")

                question = add_file_path(
                    dataset_i, file_path=gaia_dataset_path, split=args.split
                )["Question"]

                task = Task(input=question, agent=super_agent, conf=TaskConfig())
                result = Runners.sync_run_task(task=task)

                # Fix: .get('answer') may return None, which would make
                # re.search raise TypeError; fall back to the empty string.
                match = re.search(
                    r"<answer>(.*?)</answer>", result[task.id].get("answer") or ""
                )
                if match:
                    answer = match.group(1)
                    logging.info(f"Agent answer: {answer}")
                    logging.info(f"Correct answer: {dataset_i['Final answer']}")

                    # Fix: score once and reuse (original called the scorer
                    # twice); also log the question index consistently on both
                    # branches.
                    is_correct = question_scorer(answer, dataset_i["Final answer"])
                    if is_correct:
                        logging.info(f"Question {i} Correct!")
                    else:
                        logging.info(f"Question {i} Incorrect!")

                    # Create the new result record
                    new_result = {
                        "task_id": dataset_i["task_id"],
                        "level": dataset_i["Level"],
                        "question": question,
                        "answer": dataset_i["Final answer"],
                        "response": answer,
                        "is_correct": is_correct,
                    }

                    # Check if this task_id already exists in results
                    existing_index = next(
                        (
                            idx
                            for idx, result_item in enumerate(results)
                            if result_item["task_id"] == dataset_i["task_id"]
                        ),
                        None,
                    )

                    if existing_index is not None:
                        # Update existing record
                        results[existing_index] = new_result
                        logging.info(
                            f"Updated existing record for task_id: {dataset_i['task_id']}"
                        )
                    else:
                        # Append new record
                        results.append(new_result)
                        logging.info(
                            f"Added new record for task_id: {dataset_i['task_id']}"
                        )

            except Exception:
                logging.error(f"Error processing {i}: {traceback.format_exc()}")
                continue
    except KeyboardInterrupt:
        pass
    finally:
        # report and persist the checkpoint, even on interrupt
        report_results(results)
        with open(
            os.getenv("LOG_FILE_PATH") + "/results.json", "w", encoding="utf-8"
        ) as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
|
examples/gaia/stream_aworld_ui.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import time
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from time import sleep
|
5 |
+
|
6 |
+
from rich.table import Table
|
7 |
+
|
8 |
+
from aworld.output.utils import consume_content
|
9 |
+
from rich.status import Status
|
10 |
+
from rich.console import Console
|
11 |
+
|
12 |
+
from aworld.output import MessageOutput, WorkSpace
|
13 |
+
from aworld.output.base import StepOutput, ToolResultOutput
|
14 |
+
from aworld.output.ui.base import AworldUI
|
15 |
+
import json
|
16 |
+
@dataclass
class StreamAworldUI(AworldUI):
    """Rich-console UI that streams agent output, tagging each chunk for GAIA parsing."""

    console: Console = field(default_factory=Console)
    # Active spinner between a step's START and FINISHED/FAILED events.
    status: Status = None
    workspace: WorkSpace = None
    # Prefix that downstream consumers grep for to extract streamed content.
    gaia_output_line_tag = "GA_FMT_CONTENT:"

    async def message_output(self, __output__: MessageOutput):
        """Stream reasoning/response content, printing one tagged JSON chunk per item."""
        result = []

        async def __log_item(item):
            result.append(item)
            self.console.print(f"{self.gaia_output_line_tag} {json.dumps(item)}", end="")

        if __output__.reason_generator or __output__.response_generator:
            if __output__.reason_generator:
                await consume_content(__output__.reason_generator, __log_item)
            # BUG FIX: this condition previously re-checked ``reason_generator``,
            # so the response stream was skipped whenever reasoning was absent.
            if __output__.response_generator:
                await consume_content(__output__.response_generator, __log_item)
        else:
            await consume_content(__output__.reasoning, __log_item)
            await consume_content(__output__.response, __log_item)
        # if __output__.tool_calls:
        #     await consume_content(__output__.tool_calls, __log_item)
        self.console.print("")

    async def tool_result(self, output: ToolResultOutput):
        """Render one tool invocation (name, arguments, result) as a console table."""
        table = Table(
            show_header=False,
            header_style="bold magenta",
            title=f"Call Tools#ID_{output.origin_tool_call.id}",
        )
        table.add_column("name", style="dim", width=12)
        table.add_column("content")
        table.add_row("function_name", output.origin_tool_call.function.name)
        table.add_row("arguments", output.origin_tool_call.function.arguments)
        table.add_row("result", output.data)
        self.console.print(table)

    async def step(self, output: StepOutput):
        """Show step lifecycle transitions, keeping a live spinner while running."""
        if output.status == "START":
            self.console.print(f"[bold green]{output.name} ✈️START ...")
            self.status = self.console.status(f"[bold green]{output.name} RUNNING ...")
            self.status.start()
        elif output.status == "FINISHED":
            self.status.stop()
            self.console.print(f"[bold green]{output.name} 🛬FINISHED ...")
        elif output.status == "FAILED":
            self.status.stop()
            self.console.print(f"[bold red]{output.name} 💥FAILED ...")
        else:
            self.status.stop()
            self.console.print(f"============={output.name} ❓❓❓UNKNOWN#{output.status} ======================")
|
examples/gaia/utils.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import string
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict, List, Optional
|
6 |
+
import logging
|
7 |
+
from tabulate import tabulate
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
def normalize_str(input_str, remove_punct=True) -> str:
    """Lowercase *input_str* with all whitespace removed; optionally strip punctuation."""
    # Collapse every whitespace character, then fold case.
    compact = re.sub(r"\s", "", input_str).lower()
    if not remove_punct:
        return compact
    # Remove all ASCII punctuation in a single C-level translate pass.
    return compact.translate(str.maketrans("", "", string.punctuation))
|
19 |
+
|
20 |
+
|
21 |
+
def split_string(s: str, char_list: Optional[List[str]] = None) -> list[str]:
    """Split *s* on any single character from *char_list* (default: comma and semicolon)."""
    delimiters = [",", ";"] if char_list is None else char_list
    # Build a character class from the delimiters and split on it.
    return re.split(f"[{''.join(delimiters)}]", s)
|
26 |
+
|
27 |
+
|
28 |
+
def normalize_number_str(number_str: str) -> float:
    """Parse a numeric string, tolerating '$', '%' and thousands separators.

    Returns ``float('inf')`` when the cleaned string still is not a number,
    so a failed parse never compares equal to a real ground-truth value.
    """
    cleaned = number_str.translate(str.maketrans("", "", "$%,"))
    try:
        return float(cleaned)
    except ValueError:
        logger.error(f"String {number_str} cannot be normalized to number str.")
        return float("inf")
|
36 |
+
|
37 |
+
|
38 |
+
def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """GAIA official-style scorer: compare *model_answer* against *ground_truth*.

    The comparison strategy is chosen from the ground truth:
      * numeric                  -> normalize the answer and compare as floats;
      * comma/semicolon-joined   -> element-wise comparison of the split lists;
      * plain string             -> whitespace-stripped, punctuation-free equality.
    Any unexpected error scores the answer as incorrect.
    """

    def _is_number(value: Any) -> bool:
        # True when float() accepts the value.
        try:
            float(value)
        except ValueError:
            return False
        return True

    def _elem_match(ma_elem: str, gt_elem: str) -> bool:
        # Score one list element: numerically when the ground-truth element
        # parses as a number, otherwise as a normalized (punct-kept) string.
        if _is_number(gt_elem):
            return normalize_number_str(ma_elem) == float(gt_elem)
        return normalize_str(ma_elem, remove_punct=False) == normalize_str(
            gt_elem, remove_punct=False
        )

    try:
        if _is_number(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            return normalize_number_str(model_answer) == float(ground_truth)

        if any(sep in ground_truth for sep in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma separated list.")
            gt_parts = split_string(ground_truth)
            ma_parts = split_string(model_answer)
            if len(gt_parts) != len(ma_parts):
                logger.warning("Answer lists have different lengths, returning False.")
                return False
            # Evaluate every element (no short-circuit) so per-element logging
            # from the number normalizer matches the original behavior.
            comparisons = [
                _elem_match(ma, gt) for ma, gt in zip(ma_parts, gt_parts)
            ]
            return all(comparisons)

        logger.info(f"Evaluating {model_answer} as a string.")
        return normalize_str(model_answer) == normalize_str(ground_truth)
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False
|
79 |
+
|
80 |
+
|
81 |
+
def load_dataset_meta(path: str, split: str = "validation"):
    """Load GAIA task records from ``<path>/<split>/metadata.jsonl`` as a list.

    The placeholder task ``0-0-0-0-0`` is skipped, and non-empty ``file_name``
    values are resolved to ``Path`` objects under the split directory.
    """
    split_dir = Path(path) / split
    tasks = []
    with open(split_dir / "metadata.jsonl", "r", encoding="utf-8") as meta_file:
        for raw_line in meta_file:
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = split_dir / record["file_name"]
            tasks.append(record)
    return tasks
|
95 |
+
|
96 |
+
|
97 |
+
def load_dataset_meta_dict(path: str, split: str = "validation"):
    """Load GAIA task records keyed by ``task_id`` from ``<path>/<split>/metadata.jsonl``.

    Applies the same filtering and ``file_name`` resolution as
    ``load_dataset_meta``, but returns a dict for O(1) lookup by task id.
    """
    split_dir = Path(path) / split
    tasks = {}
    with open(split_dir / "metadata.jsonl", "r", encoding="utf-8") as meta_file:
        for raw_line in meta_file:
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = split_dir / record["file_name"]
            tasks[record["task_id"]] = record
    return tasks
|
111 |
+
|
112 |
+
|
113 |
+
def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
):
    """Append the task's attachment path (with a type-specific hint) to task["Question"].

    Args:
        task: GAIA task record; mutated in place when it has a non-empty
            ``file_name``.
        file_path: dataset root directory.
        split: dataset split sub-directory under the root.

    Returns:
        The same ``task`` dict (also mutated in place).
    """
    if task["file_name"]:
        attachment = Path(f"{file_path}/{split}") / task["file_name"]
        suffix = attachment.suffix
        if suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {attachment}"

        elif suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {attachment}"

        # BUG FIX: ".xls" was written as "xls" — Path.suffix always includes the
        # leading dot, so .xls attachments fell through to the generic branch.
        elif suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {attachment}, for processing excel file,"
                " you can use the excel tool or write python code to process the file"
                " step-by-step and get the information."
            )
        elif suffix in [".py"]:
            task["Question"] += f" Here are the necessary python files: {attachment}"

        else:
            task["Question"] += f" Here are the necessary files: {attachment}"

    return task
|
137 |
+
|
138 |
+
|
139 |
+
def report_results(entries):
    """Log an accuracy summary (overall and per GAIA level) for scored entries.

    Args:
        entries: list of result dicts carrying at least ``level`` and
            ``is_correct`` keys.
    """
    total_entries = len(entries)
    if total_entries == 0:
        # BUG FIX: an empty result list previously raised ZeroDivisionError
        # when computing the overall accuracy.
        logger.info("No entries to report.")
        return

    total_correct = 0
    # Per-level counters: {level: {"total", "correct", "accuracy"}}
    level_stats = {}

    for entry in entries:
        level = entry.get("level")
        stats = level_stats.setdefault(level, {"total": 0, "correct": 0, "accuracy": 0})
        stats["total"] += 1
        if entry.get("is_correct", False):
            total_correct += 1
            stats["correct"] += 1

    # Calculate accuracy for each level.
    for stats in level_stats.values():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Overall statistics.
    logger.info("Overall Statistics:")
    overall_accuracy = (total_correct / total_entries) * 100
    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    # BUG FIX: the stdlib ``logging.Logger`` has no ``success`` method (that is
    # a loguru-only API); ``logger.success`` raised AttributeError here.
    logger.info(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    # Per-level statistics.
    logger.info("Statistics by Level:")
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]
    level_table = []
    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )

    logger.info(tabulate(level_table, headers=headers, tablefmt="grid"))