Duibonduil committed on
Commit 900b15b · verified · 1 Parent(s): d8445e5

Upload 4 files

AWorld-main/aworlddistributed/aworldspace/utils/loader.py ADDED
@@ -0,0 +1,197 @@
+ import importlib.util
+ import json
+ import logging
+ import os
+ import subprocess
+ import sys
+ import traceback
+
+ from aworldspace.base import AGENT_SPACE
+ import aworld.trace as trace  # noqa
+
+ from config import AGENTS_DIR
+
+ if not os.path.exists(AGENTS_DIR):
+     os.makedirs(AGENTS_DIR)
+
+ PIPELINES = {}
+ PIPELINE_MODULES = {}
+
+
+ def get_all_pipelines():
+     pipelines = {}
+     for pipeline_id, pipeline in PIPELINE_MODULES.items():
+         if hasattr(pipeline, "type"):
+             if pipeline.type == "manifold":
+                 # pipeline.pipelines may be a callable or a plain list
+                 if callable(pipeline.pipelines):
+                     manifold_pipelines = pipeline.pipelines()
+                 else:
+                     manifold_pipelines = pipeline.pipelines
+
+                 for p in manifold_pipelines:
+                     manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
+
+                     manifold_pipeline_name = p["name"]
+                     if hasattr(pipeline, "name"):
+                         manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"
+
+                     pipelines[manifold_pipeline_id] = {
+                         "module": pipeline_id,
+                         "type": pipeline.type,
+                         "id": manifold_pipeline_id,
+                         "name": manifold_pipeline_name,
+                         "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
+                     }
+             if pipeline.type == "filter":
+                 pipelines[pipeline_id] = {
+                     "module": pipeline_id,
+                     "type": pipeline.type,
+                     "id": pipeline_id,
+                     "name": pipeline.name if hasattr(pipeline, "name") else pipeline_id,
+                     "pipelines": (
+                         pipeline.valves.pipelines
+                         if hasattr(pipeline, "valves") and hasattr(pipeline.valves, "pipelines")
+                         else []
+                     ),
+                     "priority": (
+                         pipeline.valves.priority
+                         if hasattr(pipeline, "valves") and hasattr(pipeline.valves, "priority")
+                         else 0
+                     ),
+                     "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
+                 }
+         else:
+             pipelines[pipeline_id] = {
+                 "module": pipeline_id,
+                 "type": "pipe",
+                 "id": pipeline_id,
+                 "name": pipeline.name if hasattr(pipeline, "name") else pipeline_id,
+                 "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
+             }
+
+     return pipelines
+
+
+ def parse_frontmatter(content):
+     frontmatter = {}
+     for line in content.split("\n"):
+         if ":" in line:
+             key, value = line.split(":", 1)
+             frontmatter[key.strip().lower()] = value.strip()
+     return frontmatter
+
+
+ def install_frontmatter_requirements(requirements):
+     if requirements:
+         req_list = [req.strip() for req in requirements.split(",")]
+         for req in req_list:
+             print(f"Installing requirement: {req}")
+             subprocess.check_call([sys.executable, "-m", "pip", "install", req])
+     else:
+         print("No requirements found in frontmatter.")
+
+
+ async def load_module_from_path(module_name, module_path):
+     try:
+         # Read the module content
+         with open(module_path, "r") as file:
+             content = file.read()
+
+         # Parse frontmatter from a leading docstring, if present
+         frontmatter = {}
+         if content.startswith('"""'):
+             end = content.find('"""', 3)
+             if end != -1:
+                 frontmatter_content = content[3:end]
+                 frontmatter = parse_frontmatter(frontmatter_content)
+
+         # Install requirements if specified
+         if "requirements" in frontmatter:
+             install_frontmatter_requirements(frontmatter["requirements"])
+
+         # Load the module
+         spec = importlib.util.spec_from_file_location(module_name, module_path)
+         module = importlib.util.module_from_spec(spec)
+         spec.loader.exec_module(module)
+         logging.info(f"Loaded module: {module.__name__}")
+         if hasattr(module, "Pipeline"):
+             return module.Pipeline()
+         else:
+             logging.error(f"Failed to load module {module.__name__}: no Pipeline class found")
+             raise Exception("No Pipeline class found")
+     except Exception as e:
+         logging.error(f"Error loading module: {module_name}, error is {e}")
+         traceback.print_exc()
+         # Move the file to the error folder
+         failed_pipelines_folder = os.path.join(AGENTS_DIR, "failed")
+         if not os.path.exists(failed_pipelines_folder):
+             os.makedirs(failed_pipelines_folder)
+
+         # failed_file_path = os.path.join(failed_pipelines_folder, f"{module_name}.py")
+         # if module_path.__contains__(PIPELINES_DIR):
+         #     os.rename(module_path, failed_file_path)
+         print(e)
+         return None
+
+
+ async def load_modules_from_directory(directory):
+     logging.info(f"load_modules_from_directory: {directory}")
+     global PIPELINE_MODULES
+
+     for filename in os.listdir(directory):
+         if filename.endswith(".py"):
+             module_name = filename[:-3]  # Remove the .py extension
+             module_path = os.path.join(directory, filename)
+
+             # Create a subfolder matching the filename without the .py extension
+             subfolder_path = os.path.join(directory, module_name)
+             if not os.path.exists(subfolder_path):
+                 os.makedirs(subfolder_path)
+                 logging.info(f"Created subfolder: {subfolder_path}")
+
+             # Create a valves.json file if it doesn't exist
+             valves_json_path = os.path.join(subfolder_path, "valves.json")
+             if not os.path.exists(valves_json_path):
+                 with open(valves_json_path, "w") as f:
+                     json.dump({}, f)
+                 logging.info(f"Created valves.json in: {subfolder_path}")
+
+             pipeline = await load_module_from_path(module_name, module_path)
+             if pipeline:
+                 # Overwrite pipeline.valves with values from valves.json
+                 if os.path.exists(valves_json_path):
+                     with open(valves_json_path, "r") as f:
+                         valves_json = json.load(f)
+                     if hasattr(pipeline, "valves"):
+                         ValvesModel = pipeline.valves.__class__
+                         # Build a ValvesModel from defaults, overridden by valves_json
+                         combined_valves = {
+                             **pipeline.valves.model_dump(),
+                             **valves_json,
+                         }
+                         pipeline.valves = ValvesModel(**combined_valves)
+
+                         logging.info(f"Updated valves for module: {module_name}")
+
+                 pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
+                 PIPELINE_MODULES[pipeline_id] = pipeline
+
+                 logging.info(f"Loaded module successfully: {module_name}")
+             else:
+                 logging.warning(f"No Pipeline class found in {module_name}")
+
+     AGENT_SPACE.agent_modules = PIPELINE_MODULES
+     AGENT_SPACE.agents_meta = get_all_pipelines()
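
The valves override above leans on Pydantic v2's `model_dump()`: the pipeline's own `Valves` model supplies defaults, and keys read from `valves.json` win on conflict. A minimal self-contained sketch of that merge pattern (the `Valves` fields here are hypothetical stand-ins):

```python
from pydantic import BaseModel


class Valves(BaseModel):
    # Hypothetical fields standing in for a real pipeline's valves.
    pipelines: list = []
    priority: int = 0


defaults = Valves()
valves_json = {"priority": 5}  # as if read from <module>/valves.json

# Dump defaults to a dict, let the JSON values overwrite, re-validate.
merged = Valves(**{**defaults.model_dump(), **valves_json})
print(merged)  # pipelines=[] priority=5
```

Re-instantiating through the model (rather than mutating attributes) keeps Pydantic validation in the loop, so a malformed `valves.json` fails loudly at load time.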
AWorld-main/aworlddistributed/aworldspace/utils/log.py ADDED
@@ -0,0 +1,75 @@
+ import logging
+ import os
+ from datetime import datetime
+ from typing import Optional
+
+ from aworld.models.model_response import ModelResponse
+
+ from base import AworldTask, AworldTaskResult
+ from config import ROOT_LOG
+
+
+ class TaskLogger:
+     """Task submission logger."""
+
+     def __init__(self, log_file: str = "aworld_task_submissions.log"):
+         self.log_file = os.path.join(ROOT_LOG, 'task_logs', log_file)
+         self._ensure_log_file_exists()
+
+     def _ensure_log_file_exists(self):
+         """Ensure the log file and its parent directory exist."""
+         if not os.path.exists(self.log_file):
+             os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
+             with open(self.log_file, 'w', encoding='utf-8') as f:
+                 f.write("# Aworld Task Submission Log\n")
+                 f.write(
+                     "# Format: [timestamp] task_id | agent_id | server | status | "
+                     "agent_answer | correct_answer | is_correct | details\n\n")
+
+     def log_task_submission(self, task: AworldTask, status: str, details: str = "",
+                             task_result: Optional[AworldTaskResult] = None):
+         """Append one task submission record to the log file."""
+         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+         data = task_result.data if task_result and task_result.data else {}
+         log_entry = (
+             f"[{timestamp}] {task.task_id} | {task.agent_id} | {task.node_id} | {status} | "
+             f"{data.get('agent_answer')} | {data.get('correct_answer')} | "
+             f"{data.get('gaia_correct')} | {details}\n"
+         )
+
+         try:
+             with open(self.log_file, 'a', encoding='utf-8') as f:
+                 f.write(log_entry)
+         except Exception as e:
+             logging.error(f"Failed to write task submission log: {e}")
+
+     def log_task_result(self, task: AworldTask, result: ModelResponse):
+         """Append the model output for a task to a per-task markdown file."""
+         try:
+             date_str = datetime.now().strftime("%Y%m%d")
+             result_dir = os.path.join(ROOT_LOG, 'task_logs', 'result', date_str)
+             os.makedirs(result_dir, exist_ok=True)
+
+             md_file = f"{result_dir}/{task.task_id}.md"
+
+             content_parts = []
+             if hasattr(result, 'content') and result.content:
+                 if isinstance(result.content, list):
+                     content_parts.extend(result.content)
+                 else:
+                     content_parts.append(str(result.content))
+
+             file_exists = os.path.exists(md_file)
+             with open(md_file, 'a', encoding='utf-8') as f:
+                 if not file_exists:
+                     f.write(f"# Task Result: {task.task_id}\n\n")
+                     f.write(f"**Agent ID:** {task.agent_id}\n\n")
+                     f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
+                     f.write("## Content\n\n")
+
+                 if content_parts:
+                     for content in content_parts:
+                         f.write(f"{content}\n\n")
+                 else:
+                     f.write("No content available.\n\n")
+
+             return md_file
+
+         except Exception as e:
+             logging.error(f"Failed to write task result log: {e}")
+             return None
+
+
+ task_logger = TaskLogger(log_file=f"aworld_task_submissions_{datetime.now().strftime('%Y%m%d')}.log")
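
`log_task_result` appends to one markdown file per task and writes the header block only on first touch. A standalone sketch of that append-once-header pattern (file name and content are illustrative):

```python
import os
from datetime import datetime


def append_result(md_file: str, task_id: str, content: str) -> None:
    # Write the header only when the file is first created,
    # then append content chunks on every subsequent call.
    file_exists = os.path.exists(md_file)
    with open(md_file, "a", encoding="utf-8") as f:
        if not file_exists:
            f.write(f"# Task Result: {task_id}\n\n")
            f.write(f"**Timestamp:** {datetime.now():%Y-%m-%d %H:%M:%S}\n\n")
            f.write("## Content\n\n")
        f.write(f"{content}\n\n")


append_result("demo.md", "t-001", "first chunk")
append_result("demo.md", "t-001", "second chunk")  # header is not repeated
```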
AWorld-main/aworlddistributed/aworldspace/utils/mcp_utils.py ADDED
@@ -0,0 +1,199 @@
+ import os
+
+
+ def load_all_mcp_config():
+     return {
+         "mcpServers": {
+             "e2b-server": {
+                 "command": "npx",
+                 "args": ["-y", "@e2b/mcp-server"],
+                 "env": {
+                     "E2B_API_KEY": os.environ["E2B_API_KEY"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "filesystem": {
+                 "command": "npx",
+                 "args": [
+                     "-y",
+                     "@modelcontextprotocol/server-filesystem",
+                     "${FILESYSTEM_SERVER_WORKDIR}",
+                 ],
+             },
+             "terminal-controller": {
+                 "command": "python",
+                 "args": ["-m", "terminal_controller"],
+                 "env": {"SESSION_REQUEST_CONNECT_TIMEOUT": "300"},
+             },
+             "calculator": {
+                 "command": "python",
+                 "args": ["-m", "mcp_server_calculator"],
+                 "env": {"SESSION_REQUEST_CONNECT_TIMEOUT": "20"},
+             },
+             "excel": {
+                 "command": "uvx",
+                 "args": ["excel-mcp-server", "stdio"],
+                 "env": {
+                     "EXCEL_MCP_PAGING_CELLS_LIMIT": "4000",
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+             "google-search": {
+                 "command": "npx",
+                 "args": ["-y", "@adenot/mcp-google-search"],
+                 "env": {
+                     "GOOGLE_API_KEY": os.environ["GOOGLE_API_KEY"],
+                     "GOOGLE_SEARCH_ENGINE_ID": os.environ["GOOGLE_CSE_ID"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "ms-playwright": {
+                 "command": "npx",
+                 "args": ["@playwright/mcp@latest", "--no-sandbox", "--headless", "--isolated"],
+                 "env": {
+                     "PLAYWRIGHT_TIMEOUT": "120000",
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+             "audio_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.audio_server"],
+                 "env": {
+                     "AUDIO_LLM_API_KEY": os.environ["AUDIO_LLM_API_KEY"],
+                     "AUDIO_LLM_BASE_URL": os.environ["AUDIO_LLM_BASE_URL"],
+                     "AUDIO_LLM_MODEL_NAME": os.environ["AUDIO_LLM_MODEL_NAME"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "image_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.image_server"],
+                 "env": {
+                     "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
+                     "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
+                     "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "youtube_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.youtube_server"],
+                 "env": {
+                     "CHROME_DRIVER_PATH": os.environ["CHROME_DRIVER_PATH"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+             "video_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.video_server"],
+                 "env": {
+                     "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
+                     "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
+                     "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "search_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.search_server"],
+                 "env": {
+                     "GOOGLE_API_KEY": os.environ["GOOGLE_API_KEY"],
+                     "GOOGLE_CSE_ID": os.environ["GOOGLE_CSE_ID"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "60",
+                 },
+             },
+             "download_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.download_server"],
+                 "env": {"SESSION_REQUEST_CONNECT_TIMEOUT": "120"},
+             },
+             "document_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.document_server"],
+                 "env": {"SESSION_REQUEST_CONNECT_TIMEOUT": "120"},
+             },
+             "browser_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.browser_server"],
+                 "env": {
+                     "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
+                     "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
+                     "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+             "reasoning_server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.reasoning_server"],
+                 "env": {
+                     "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
+                     "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
+                     "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+             "e2b-code-server": {
+                 "command": "python",
+                 "args": ["-m", "mcp_servers.e2b_code_server"],
+                 "env": {
+                     "E2B_API_KEY": os.environ["E2B_API_KEY"],
+                     "SESSION_REQUEST_CONNECT_TIMEOUT": "120",
+                 },
+             },
+         }
+     }
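
Note the two lookup styles in this config: servers whose credentials are mandatory use `os.environ[...]`, which raises `KeyError` the moment the config is built with the variable unset, while the shared LLM settings use `os.environ.get(...)` and quietly pass `None` through. A tiny sketch of the difference:

```python
import os

os.environ.setdefault("GOOGLE_API_KEY", "demo-key")  # stand-in value

required = os.environ["GOOGLE_API_KEY"]   # raises KeyError if unset: fail fast
optional = os.environ.get("LLM_API_KEY")  # returns None if unset: defer the failure
print(required, optional)
```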
AWorld-main/aworlddistributed/aworldspace/utils/utils.py ADDED
@@ -0,0 +1,344 @@
+ import inspect
+ import json
+ import re
+ import string
+ import time
+ import uuid
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, get_type_hints
+
+ from loguru import logger
+ from tabulate import tabulate
+
+
+ def normalize_str(input_str, remove_punct=True) -> str:
+     no_spaces = re.sub(r"\s", "", input_str)
+     if remove_punct:
+         translator = str.maketrans("", "", string.punctuation)
+         return no_spaces.lower().translate(translator)
+     else:
+         return no_spaces.lower()
+
+
+ def split_string(s: str, char_list: Optional[List[str]] = None) -> List[str]:
+     if char_list is None:
+         char_list = [",", ";"]
+     pattern = f"[{''.join(char_list)}]"
+     return re.split(pattern, s)
+
+
+ def normalize_number_str(number_str: str) -> float:
+     for char in ["$", "%", ","]:
+         number_str = number_str.replace(char, "")
+     try:
+         return float(number_str)
+     except ValueError:
+         logger.error(f"String {number_str} cannot be normalized to a number.")
+         return float("inf")
+
+
+ def question_scorer(model_answer: str, ground_truth: str) -> bool:
+     def is_float(element: Any) -> bool:
+         try:
+             float(element)
+             return True
+         except ValueError:
+             return False
+
+     try:
+         if is_float(ground_truth):
+             logger.info(f"Evaluating {model_answer} as a number.")
+             normalized_answer = normalize_number_str(model_answer)
+             return normalized_answer == float(ground_truth)
+
+         elif any(char in ground_truth for char in [",", ";"]):
+             logger.info(f"Evaluating {model_answer} as a comma-separated list.")
+             gt_elems = split_string(ground_truth)
+             ma_elems = split_string(model_answer)
+
+             if len(gt_elems) != len(ma_elems):
+                 logger.warning("Answer lists have different lengths, returning False.")
+                 return False
+
+             comparisons = []
+             for ma_elem, gt_elem in zip(ma_elems, gt_elems):
+                 if is_float(gt_elem):
+                     normalized_ma_elem = normalize_number_str(ma_elem)
+                     comparisons.append(normalized_ma_elem == float(gt_elem))
+                 else:
+                     ma_elem = normalize_str(ma_elem, remove_punct=False)
+                     gt_elem = normalize_str(gt_elem, remove_punct=False)
+                     comparisons.append(ma_elem == gt_elem)
+             return all(comparisons)
+         else:
+             logger.info(f"Evaluating {model_answer} as a string.")
+             ma_elem = normalize_str(model_answer)
+             gt_elem = normalize_str(ground_truth)
+             return ma_elem == gt_elem
+     except Exception as e:
+         logger.error(f"Error during evaluation: {e}")
+         return False
+
+
+ def load_dataset_meta(path: str, split: str = "validation"):
+     data_dir = Path(path) / split
+
+     dataset = []
+     with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
+         for line in metaf:
+             data = json.loads(line)
+             if data["task_id"] == "0-0-0-0-0":
+                 continue
+             if data["file_name"]:
+                 data["file_name"] = data_dir / data["file_name"]
+             dataset.append(data)
+     return dataset
+
+
+ def load_dataset_meta_dict(path: str, split: str = "validation"):
+     data_dir = Path(path) / split
+
+     dataset = {}
+     with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
+         for line in metaf:
+             data = json.loads(line)
+             if data["task_id"] == "0-0-0-0-0":
+                 continue
+             if data["file_name"]:
+                 data["file_name"] = data_dir / data["file_name"]
+             dataset[data["task_id"]] = data
+     return dataset
+
+
+ def add_file_path(
+     task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
+ ):
+     if task["file_name"]:
+         file_path = Path(f"{file_path}/{split}") / task["file_name"]
+         if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
+             task["Question"] += f" Here are the necessary document files: {file_path}"
+         elif file_path.suffix in [".jpg", ".jpeg", ".png"]:
+             task["Question"] += f" Here are the necessary image files: {file_path}"
+         elif file_path.suffix in [".xlsx", ".xls", ".csv"]:
+             task["Question"] += (
+                 f" Here are the necessary table files: {file_path}; to process an Excel file,"
+                 " you can use the excel tool or write Python code to process the file"
+                 " step by step and get the information."
+             )
+         elif file_path.suffix in [".py"]:
+             task["Question"] += f" Here are the necessary python files: {file_path}"
+         else:
+             task["Question"] += f" Here are the necessary files: {file_path}"
+
+     return task
+
+
+ def report_results(entries):
+     # Initialize counters
+     total_entries = len(entries)
+     total_correct = 0
+
+     # Initialize level statistics
+     level_stats = {}
+
+     # Process each entry
+     for entry in entries:
+         level = entry.get("level")
+         is_correct = entry.get("is_correct", False)
+
+         # Initialize level stats if not already present
+         if level not in level_stats:
+             level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}
+
+         # Update counters
+         level_stats[level]["total"] += 1
+         if is_correct:
+             total_correct += 1
+             level_stats[level]["correct"] += 1
+
+     # Calculate accuracy for each level
+     for level, stats in level_stats.items():
+         if stats["total"] > 0:
+             stats["accuracy"] = (stats["correct"] / stats["total"]) * 100
+
+     # Print overall statistics with colorful logging
+     logger.info("Overall Statistics:")
+     overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0
+
+     # Create overall statistics table
+     overall_table = [
+         ["Total Entries", total_entries],
+         ["Total Correct", total_correct],
+         ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
+     ]
+     logger.success(tabulate(overall_table, tablefmt="grid"))
+     logger.info("")
+
+     # Create level statistics table
+     logger.info("Statistics by Level:")
+     level_table = []
+     headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]
+
+     for level in sorted(level_stats.keys()):
+         stats = level_stats[level]
+         level_table.append(
+             [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
+         )
+
+     logger.success(tabulate(level_table, headers=headers, tablefmt="grid"))
+
+
+ def stream_message_template(model: str, message: str):
+     return {
+         "id": f"{model}-{str(uuid.uuid4())}",
+         "object": "chat.completion.chunk",
+         "created": int(time.time()),
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "delta": {"content": message},
+                 "logprobs": None,
+                 "finish_reason": None,
+             }
+         ],
+     }
+
+
+ def get_last_user_message(messages: List[dict]) -> Optional[str]:
+     for message in reversed(messages):
+         if message["role"] == "user":
+             if isinstance(message["content"], list):
+                 for item in message["content"]:
+                     if item["type"] == "text":
+                         return item["text"]
+             return message["content"]
+     return None
+
+
+ def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
+     for message in reversed(messages):
+         if message["role"] == "assistant":
+             if isinstance(message["content"], list):
+                 for item in message["content"]:
+                     if item["type"] == "text":
+                         return item["text"]
+             return message["content"]
+     return None
+
+
+ def get_system_message(messages: List[dict]) -> Optional[dict]:
+     for message in messages:
+         if message["role"] == "system":
+             return message
+     return None
+
+
+ def remove_system_message(messages: List[dict]) -> List[dict]:
+     return [message for message in messages if message["role"] != "system"]
+
+
+ def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
+     return get_system_message(messages), remove_system_message(messages)
+
+
+ def add_or_update_system_message(content: str, messages: List[dict]) -> List[dict]:
+     """
+     Adds a new system message at the beginning of the messages list
+     or prepends to the existing system message at the beginning.
+
+     :param content: The message to be added or prepended.
+     :param messages: The list of message dictionaries.
+     :return: The updated list of message dictionaries.
+     """
+
+     if messages and messages[0].get("role") == "system":
+         messages[0]["content"] = f"{content}\n{messages[0]['content']}"
+     else:
+         # Insert at the beginning
+         messages.insert(0, {"role": "system", "content": content})
+
+     return messages
+
+
+ def doc_to_dict(docstring):
+     lines = docstring.split("\n")
+     # Guard against single-line docstrings; doc_to_dict is also called with
+     # the bare function name when __doc__ is missing.
+     description = lines[1].strip() if len(lines) > 1 else lines[0].strip()
+     param_dict = {}
+
+     for line in lines:
+         if ":param" in line:
+             line = line.replace(":param", "").strip()
+             param, desc = line.split(":", 1)
+             param_dict[param.strip()] = desc.strip()
+     ret_dict = {"description": description, "params": param_dict}
+     return ret_dict
+
+
+ def get_tools_specs(tools) -> List[dict]:
+     function_list = [
+         {"name": func, "function": getattr(tools, func)}
+         for func in dir(tools)
+         if callable(getattr(tools, func)) and not func.startswith("__")
+     ]
+
+     specs = []
+
+     for function_item in function_list:
+         function_name = function_item["name"]
+         function = function_item["function"]
+
+         function_doc = doc_to_dict(function.__doc__ or function_name)
+         specs.append(
+             {
+                 "name": function_name,
+                 # TODO: multi-line desc?
+                 "description": function_doc.get("description", function_name),
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         param_name: {
+                             "type": param_annotation.__name__.lower(),
+                             **(
+                                 {"enum": param_annotation.__args__}
+                                 if hasattr(param_annotation, "__args__")
+                                 else {}
+                             ),
+                             "description": function_doc.get("params", {}).get(
+                                 param_name, param_name
+                             ),
+                         }
+                         for param_name, param_annotation in get_type_hints(function).items()
+                         if param_name != "return"
+                     },
+                     "required": [
+                         name
+                         for name, param in inspect.signature(function).parameters.items()
+                         if param.default is param.empty
+                     ],
+                 },
+             }
+         )
+
+     return specs
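
A hedged usage sketch of `question_scorer`'s three comparison modes, matching the branches above; the import path is illustrative, and the expected results follow from the normalization helpers:

```python
# from aworldspace.utils.utils import question_scorer  # illustrative path

print(question_scorer("$1,234", "1234"))               # True: numeric mode strips $ , %
print(question_scorer("a; b ; C", "a,b,c"))            # True: list mode, element-wise compare
print(question_scorer("Hello, World!", "helloworld"))  # True: string mode drops punctuation
```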