Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
AWorld-main/aworlddistributed/aworldspace/utils/loader.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.util
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import subprocess
|
6 |
+
import sys
|
7 |
+
import traceback
|
8 |
+
from aworldspace.base import AGENT_SPACE
|
9 |
+
import aworld.trace as trace # noqa
|
10 |
+
|
11 |
+
from config import AGENTS_DIR
|
12 |
+
|
13 |
+
if not os.path.exists(AGENTS_DIR):
|
14 |
+
os.makedirs(AGENTS_DIR)
|
15 |
+
|
16 |
+
PIPELINES = {}
|
17 |
+
PIPELINE_MODULES = {}
|
18 |
+
|
19 |
+
def get_all_pipelines():
    """Build a metadata dict for every loaded pipeline module.

    Returns a dict keyed by pipeline id.  A "manifold" module contributes one
    entry per sub-pipeline (id ``<module_id>.<sub_id>``); a "filter" module
    contributes a single entry carrying its target pipelines and priority;
    a module without a ``type`` attribute is registered as a plain "pipe".
    """
    pipelines = {}
    for pipeline_id in PIPELINE_MODULES.keys():
        pipeline = PIPELINE_MODULES[pipeline_id]

        if hasattr(pipeline, "type"):
            if pipeline.type == "manifold":
                manifold_pipelines = []

                # Check if pipelines is a function or a list
                if callable(pipeline.pipelines):
                    manifold_pipelines = pipeline.pipelines()
                else:
                    manifold_pipelines = pipeline.pipelines

                for p in manifold_pipelines:
                    manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'

                    # Sub-pipeline display name is prefixed with the module
                    # name when the module defines one.
                    manifold_pipeline_name = p["name"]
                    if hasattr(pipeline, "name"):
                        manifold_pipeline_name = (
                            f"{pipeline.name}{manifold_pipeline_name}"
                        )

                    pipelines[manifold_pipeline_id] = {
                        "module": pipeline_id,
                        "type": pipeline.type if hasattr(pipeline, "type") else "pipe",
                        "id": manifold_pipeline_id,
                        "name": manifold_pipeline_name,
                        "valves": (
                            pipeline.valves if hasattr(pipeline, "valves") else None
                        ),
                    }
            if pipeline.type == "filter":
                pipelines[pipeline_id] = {
                    "module": pipeline_id,
                    "type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
                    "id": pipeline_id,
                    "name": (
                        pipeline.name if hasattr(pipeline, "name") else pipeline_id
                    ),
                    # Filters declare which pipelines they apply to and an
                    # ordering priority via their valves, when present.
                    "pipelines": (
                        pipeline.valves.pipelines
                        if hasattr(pipeline, "valves")
                        and hasattr(pipeline.valves, "pipelines")
                        else []
                    ),
                    "priority": (
                        pipeline.valves.priority
                        if hasattr(pipeline, "valves")
                        and hasattr(pipeline.valves, "priority")
                        else 0
                    ),
                    "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
                }
        else:
            # NOTE(review): this else pairs with the hasattr("type") check, so
            # a module that sets type to something other than "manifold" or
            # "filter" (e.g. an explicit "pipe") is never registered here —
            # confirm that is intended.
            pipelines[pipeline_id] = {
                "module": pipeline_id,
                "type": (pipeline.type if hasattr(pipeline, "type") else "pipe"),
                "id": pipeline_id,
                "name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
                "valves": pipeline.valves if hasattr(pipeline, "valves") else None,
            }

    return pipelines
|
84 |
+
|
85 |
+
|
86 |
+
def parse_frontmatter(content):
    """Parse ``key: value`` frontmatter lines into a dict.

    Lines without a colon are skipped; keys are lower-cased and both key and
    value are stripped of surrounding whitespace.
    """
    parsed = {}
    for raw_line in content.split("\n"):
        if ":" not in raw_line:
            continue
        key, _, value = raw_line.partition(":")
        parsed[key.strip().lower()] = value.strip()
    return parsed
|
93 |
+
|
94 |
+
|
95 |
+
def install_frontmatter_requirements(requirements):
    """Pip-install the comma-separated *requirements* string.

    Each requirement is installed with the current interpreter's pip; a falsy
    (empty/None) string is a no-op.  Raises CalledProcessError if pip fails.
    """
    if not requirements:
        print("No requirements found in frontmatter.")
        return
    for requirement in (part.strip() for part in requirements.split(",")):
        print(f"Installing requirement: {requirement}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", requirement])
|
103 |
+
|
104 |
+
|
105 |
+
async def load_module_from_path(module_name, module_path):
    """Load a pipeline module from a file path and return a ``Pipeline`` instance.

    Reads the file, parses an optional triple-quoted frontmatter header,
    pip-installs any ``requirements`` listed there, then imports the module
    and instantiates its ``Pipeline`` class.  Returns ``None`` on any failure
    (the error is logged and printed, never raised to the caller).
    """
    try:
        # Read the module content
        with open(module_path, "r") as file:
            content = file.read()

        # Parse frontmatter: an opening '"""' docstring at byte 0 is treated
        # as a key: value metadata header.
        frontmatter = {}
        if content.startswith('"""'):
            end = content.find('"""', 3)
            if end != -1:
                frontmatter_content = content[3:end]
                frontmatter = parse_frontmatter(frontmatter_content)

        # Install requirements if specified
        if "requirements" in frontmatter:
            install_frontmatter_requirements(frontmatter["requirements"])

        # Load the module via importlib — note this executes arbitrary code
        # from module_path.
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        logging.info(f"Loaded module start: {module.__name__}")
        if hasattr(module, "Pipeline"):
            return module.Pipeline()
        else:
            logging.info(f"Loaded module failed: {module.__name__ } No Pipeline class found")
            raise Exception("No Pipeline class found")
    except Exception as e:
        logging.info(f"Error loading module: {module_name}, error is {e}")
        traceback.print_exc()
        # Ensure the error folder exists (the move itself is currently
        # disabled below).
        failed_pipelines_folder = os.path.join(AGENTS_DIR, "failed")
        if not os.path.exists(failed_pipelines_folder):
            os.makedirs(failed_pipelines_folder)

        # failed_file_path = os.path.join(failed_pipelines_folder, f"{module_name}.py")
        # if module_path.__contains__(PIPELINES_DIR):
        #     os.rename(module_path, failed_file_path)
        print(e)
        return None
|
147 |
+
|
148 |
+
|
149 |
+
async def load_modules_from_directory(directory):
    """Scan *directory* for ``*.py`` pipeline modules and register them.

    For each module file this creates a same-named subfolder with a
    ``valves.json`` (if missing), loads the module, merges persisted valves
    values over the module's defaults, and registers the instance in
    ``PIPELINE_MODULES``.  Finally publishes modules and metadata to
    ``AGENT_SPACE``.
    """
    logging.info(f"load_modules_from_directory: {directory}")
    global PIPELINE_MODULES

    for filename in os.listdir(directory):
        if filename.endswith(".py"):
            module_name = filename[:-3]  # Remove the .py extension
            module_path = os.path.join(directory, filename)

            # Create subfolder matching the filename without the .py extension
            subfolder_path = os.path.join(directory, module_name)
            if not os.path.exists(subfolder_path):
                os.makedirs(subfolder_path)
                logging.info(f"Created subfolder: {subfolder_path}")

            # Create a valves.json file if it doesn't exist
            valves_json_path = os.path.join(subfolder_path, "valves.json")
            if not os.path.exists(valves_json_path):
                with open(valves_json_path, "w") as f:
                    json.dump({}, f)
                logging.info(f"Created valves.json in: {subfolder_path}")

            pipeline = await load_module_from_path(module_name, module_path)
            if pipeline:
                # Overwrite pipeline.valves with values from valves.json
                if os.path.exists(valves_json_path):
                    with open(valves_json_path, "r") as f:
                        valves_json = json.load(f)
                    if hasattr(pipeline, "valves"):
                        # Rebuild the valves model from defaults merged with
                        # the persisted JSON (JSON wins).  model_dump implies
                        # the valves are a pydantic v2 model.
                        ValvesModel = pipeline.valves.__class__
                        combined_valves = {
                            **pipeline.valves.model_dump(),
                            **valves_json,
                        }
                        valves = ValvesModel(**combined_valves)
                        pipeline.valves = valves

                        logging.info(f"Updated valves for module: {module_name}")

                # Prefer the module's self-declared id over the filename.
                pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name
                PIPELINE_MODULES[pipeline_id] = pipeline

                logging.info(f"Loaded module success: {module_name}")
            else:
                logging.warning(f"No Pipeline class found in {module_name}")

    # Publish the loaded modules and their metadata to the shared agent space.
    AGENT_SPACE.agent_modules = PIPELINE_MODULES
    AGENT_SPACE.agents_meta = get_all_pipelines()
|
AWorld-main/aworlddistributed/aworldspace/utils/log.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from datetime import datetime
|
4 |
+
|
5 |
+
from aworld.models.model_response import ModelResponse
|
6 |
+
|
7 |
+
from base import AworldTask, AworldTaskResult
|
8 |
+
from config import ROOT_LOG
|
9 |
+
|
10 |
+
|
11 |
+
class TaskLogger:
    """Task submission logger.

    Appends one pipe-separated line per task status change to a log file
    under ``ROOT_LOG/task_logs``, and writes per-task markdown result files
    under ``ROOT_LOG/task_logs/result/<date>/``.
    """

    def __init__(self, log_file: str = "aworld_task_submissions.log"):
        # Full path of the submission log; created lazily below.
        self.log_file = os.path.join(ROOT_LOG, 'task_logs' , log_file)
        self._ensure_log_file_exists()

    def _ensure_log_file_exists(self):
        """Ensure the log file and its parent directory exist."""
        if not os.path.exists(self.log_file):
            os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
            with open(self.log_file, 'w', encoding='utf-8') as f:
                f.write("# Aworld Task Submission Log\n")
                f.write(
                    "# Format: [timestamp] task_id | agent_id | server | status | agent_answer | correct_answer | is_correct | details\n\n")

    def log_task_submission(self, task: AworldTask, status: str, details: str = "",
                            task_result: AworldTaskResult = None):
        """Append one submission record for *task*.

        Answer fields are pulled from ``task_result.data`` when available,
        otherwise logged as ``None``.  Write failures are logged, not raised.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_entry = f"[{timestamp}] {task.task_id} | {task.agent_id} | {task.node_id} | {status} | {task_result.data.get('agent_answer') if task_result and task_result.data else None} | {task_result.data.get('correct_answer') if task_result and task_result.data else None} | {task_result.data.get('gaia_correct') if task_result and task_result.data else None} |{details}\n"

        try:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(log_entry)
        except Exception as e:
            logging.error(f"Failed to write task submission log: {e}")

    def log_task_result(self, task: AworldTask, result: ModelResponse):
        """Append *result* content to the task's markdown file.

        Returns the markdown file path, or ``None`` if writing failed.
        """
        try:
            date_str = datetime.now().strftime("%Y%m%d")
            result_dir = os.path.join(ROOT_LOG, 'task_logs', 'result', date_str)
            os.makedirs(result_dir, exist_ok=True)

            md_file = f"{result_dir}/{task.task_id}.md"

            # Normalize result.content into a list of printable parts.
            content_parts = []
            if hasattr(result, 'content') and result.content:
                if isinstance(result.content, list):
                    content_parts.extend(result.content)
                else:
                    content_parts.append(str(result.content))

            # Header is written only once; subsequent calls append content.
            file_exists = os.path.exists(md_file)
            with open(md_file, 'a', encoding='utf-8') as f:
                if not file_exists:
                    f.write(f"# Task Result: {task.task_id}\n\n")
                    f.write(f"**Agent ID:** {task.agent_id}\n\n")
                    f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
                    f.write("## Content\n\n")

                if content_parts:
                    for i, content in enumerate(content_parts, 1):
                        f.write(f"{content}\n\n")
                else:
                    f.write("No content available.\n\n")

            return md_file

        except Exception as e:
            logging.error(f"Failed to write task result log: {e}")
            return None
|
73 |
+
|
74 |
+
|
75 |
+
# Module-level singleton; the log file name is dated with the process start day.
task_logger = TaskLogger(log_file=f"aworld_task_submissions_{datetime.now().strftime('%Y%m%d')}.log")
|
AWorld-main/aworlddistributed/aworldspace/utils/mcp_utils.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
def load_all_mcp_config():
    """Return the full MCP server configuration dict.

    Each entry maps a server name to the command, args and environment used
    to spawn it.  NOTE(review): entries using ``os.environ[...]`` (e.g.
    E2B_API_KEY, GOOGLE_API_KEY, CHROME_DRIVER_PATH) raise KeyError at call
    time when the variable is unset, while others use ``os.environ.get`` and
    pass ``None`` through — confirm this asymmetry is intended.
    """
    return {
        "mcpServers": {
            # Third-party npm/uvx-hosted servers.
            "e2b-server": {
                "command": "npx",
                "args": [
                    "-y",
                    "@e2b/mcp-server"
                ],
                "env": {
                    "E2B_API_KEY": os.environ["E2B_API_KEY"],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "filesystem": {
                "command": "npx",
                "args": [
                    "-y",
                    "@modelcontextprotocol/server-filesystem",
                    "${FILESYSTEM_SERVER_WORKDIR}"
                ]
            },
            "terminal-controller": {
                "command": "python",
                "args": [
                    "-m",
                    "terminal_controller"
                ],
                "env": {
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "300"
                }
            },
            "calculator": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_server_calculator"
                ],
                "env": {
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "20"
                }
            },
            "excel": {
                "command": "uvx",
                "args": ["excel-mcp-server", "stdio"],
                "env": {
                    "EXCEL_MCP_PAGING_CELLS_LIMIT": "4000",
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "google-search": {
                "command": "npx",
                "args": [
                    "-y",
                    "@adenot/mcp-google-search"
                ],
                "env": {
                    "GOOGLE_API_KEY": os.environ["GOOGLE_API_KEY"],
                    "GOOGLE_SEARCH_ENGINE_ID": os.environ["GOOGLE_CSE_ID"],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "ms-playwright": {
                "command": "npx",
                "args": [
                    "@playwright/mcp@latest",
                    "--no-sandbox",
                    "--headless",
                    "--isolated"
                ],
                "env": {
                    "PLAYWRIGHT_TIMEOUT": "120000",
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            # In-repo servers launched as python modules (mcp_servers.*).
            "audio_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.audio_server"
                ],
                "env": {
                    "AUDIO_LLM_API_KEY": os.environ["AUDIO_LLM_API_KEY"],
                    "AUDIO_LLM_BASE_URL": os.environ["AUDIO_LLM_BASE_URL"],
                    "AUDIO_LLM_MODEL_NAME": os.environ["AUDIO_LLM_MODEL_NAME"],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "image_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.image_server"
                ],
                "env": {
                    "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
                    "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
                    "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "youtube_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.youtube_server"
                ],
                "env": {
                    "CHROME_DRIVER_PATH": os.environ['CHROME_DRIVER_PATH'],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "video_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.video_server"
                ],
                "env": {
                    "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
                    "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
                    "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "search_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.search_server"
                ],
                "env": {
                    "GOOGLE_API_KEY": os.environ["GOOGLE_API_KEY"],
                    "GOOGLE_CSE_ID": os.environ["GOOGLE_CSE_ID"],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "60"
                }
            },
            "download_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.download_server"
                ],
                "env": {
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "document_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.document_server"
                ],
                "env": {
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "browser_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.browser_server"
                ],
                "env": {
                    "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
                    "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
                    "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "reasoning_server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.reasoning_server"
                ],
                "env": {
                    "LLM_API_KEY": os.environ.get("LLM_API_KEY"),
                    "LLM_MODEL_NAME": os.environ.get("LLM_MODEL_NAME"),
                    "LLM_BASE_URL": os.environ.get("LLM_BASE_URL"),
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
            "e2b-code-server": {
                "command": "python",
                "args": [
                    "-m",
                    "mcp_servers.e2b_code_server"
                ],
                "env": {
                    "E2B_API_KEY": os.environ["E2B_API_KEY"],
                    "SESSION_REQUEST_CONNECT_TIMEOUT": "120"
                }
            },
        }
    }
|
AWorld-main/aworlddistributed/aworldspace/utils/utils.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import string
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Dict, List, Optional
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
from tabulate import tabulate
|
9 |
+
|
10 |
+
|
11 |
+
def normalize_str(input_str, remove_punct=True) -> str:
    """Normalize a string for comparison.

    Removes all whitespace and lower-cases the result; when *remove_punct*
    is true, also strips ASCII punctuation.
    """
    collapsed = re.sub(r"\s", "", input_str)
    lowered = collapsed.lower()
    if not remove_punct:
        return lowered
    return lowered.translate(str.maketrans("", "", string.punctuation))
|
18 |
+
|
19 |
+
|
20 |
+
def split_string(s: str, char_list: Optional[List[str]] = None) -> list[str]:
    """Split *s* on any of the single-character delimiters in *char_list*.

    Defaults to splitting on commas and semicolons.

    :param s: the string to split.
    :param char_list: single-character delimiters; defaults to ["," , ";"].
    :return: list of substrings (delimiters removed, no stripping applied).
    """
    if char_list is None:
        char_list = [",", ";"]
    # re.escape keeps regex-special delimiters (e.g. "]", "^", "-", "\\")
    # from corrupting the character class; previously they broke the pattern.
    pattern = f"[{re.escape(''.join(char_list))}]"
    return re.split(pattern, s)
|
25 |
+
|
26 |
+
|
27 |
+
def normalize_number_str(number_str: str) -> float:
    """Parse *number_str* as a float after stripping '$', '%' and ','.

    Returns ``float("inf")`` (and logs an error) when the cleaned string is
    still not a valid number.
    """
    cleaned = number_str.translate(str.maketrans("", "", "$%,"))
    try:
        return float(cleaned)
    except ValueError:
        logger.error(f"String {cleaned} cannot be normalized to number str.")
        return float("inf")
|
35 |
+
|
36 |
+
|
37 |
+
def question_scorer(model_answer: str, ground_truth: str) -> bool:
    """Score *model_answer* against *ground_truth* (GAIA-style matching).

    Numeric ground truths are compared after normalization; comma/semicolon
    separated ground truths are compared element-wise; everything else is
    compared as a normalized string.  Any internal error yields False.
    """
    def _is_float(value: Any) -> bool:
        try:
            float(value)
        except ValueError:
            return False
        return True

    try:
        if _is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            return normalize_number_str(model_answer) == float(ground_truth)

        if any(delim in ground_truth for delim in [",", ";"]):
            logger.info(f"Evaluating {model_answer} as a comma separated list.")
            gt_parts = split_string(ground_truth)
            ma_parts = split_string(model_answer)

            if len(gt_parts) != len(ma_parts):
                logger.warning("Answer lists have different lengths, returning False.")
                return False

            # Element-wise comparison: numeric elements compare as numbers,
            # the rest as normalized strings (punctuation kept).
            for ma_part, gt_part in zip(ma_parts, gt_parts):
                if _is_float(gt_part):
                    if normalize_number_str(ma_part) != float(gt_part):
                        return False
                else:
                    left = normalize_str(ma_part, remove_punct=False)
                    right = normalize_str(gt_part, remove_punct=False)
                    if left != right:
                        return False
            return True

        logger.info(f"Evaluating {model_answer} as a string.")
        return normalize_str(model_answer) == normalize_str(ground_truth)
    except Exception as e:
        logger.error(f"Error during evaluation: {e}")
        return False
|
78 |
+
|
79 |
+
|
80 |
+
def load_dataset_meta(path: str, split: str = "validation"):
    """Load GAIA task metadata as a list.

    Reads ``<path>/<split>/metadata.jsonl``, skips the placeholder task
    "0-0-0-0-0", and resolves non-empty file_name fields to absolute paths
    under the split directory.
    """
    data_dir = Path(path) / split

    records = []
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for raw_line in metaf.readlines():
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = data_dir / record["file_name"]
            records.append(record)
    return records
|
94 |
+
|
95 |
+
|
96 |
+
def load_dataset_meta_dict(path: str, split: str = "validation"):
    """Load GAIA task metadata keyed by task_id.

    Same parsing rules as load_dataset_meta (placeholder task skipped,
    file_name resolved under the split directory), but returns a dict
    mapping task_id -> record.
    """
    data_dir = Path(path) / split

    records_by_id = {}
    with open(data_dir / "metadata.jsonl", "r", encoding="utf-8") as metaf:
        for raw_line in metaf.readlines():
            record = json.loads(raw_line)
            if record["task_id"] == "0-0-0-0-0":
                continue
            if record["file_name"]:
                record["file_name"] = data_dir / record["file_name"]
            records_by_id[record["task_id"]] = record
    return records_by_id
|
110 |
+
|
111 |
+
|
112 |
+
def add_file_path(
    task: Dict[str, Any], file_path: str = "./gaia_dataset", split: str = "validation"
):
    """Append a file hint to the task's Question based on the file type.

    Resolves ``task["file_name"]`` under ``<file_path>/<split>`` and appends
    a suffix-specific instruction to ``task["Question"]``.  Tasks without a
    file_name are returned unchanged.

    :param task: GAIA task dict with "file_name" and "Question" keys.
    :param file_path: dataset root directory.
    :param split: dataset split subdirectory.
    :return: the same task dict, mutated in place.
    """
    if task["file_name"]:
        file_path = Path(f"{file_path}/{split}") / task["file_name"]
        if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
            task["Question"] += f" Here are the necessary document files: {file_path}"

        elif file_path.suffix in [".jpg", ".jpeg", ".png"]:
            task["Question"] += f" Here are the necessary image files: {file_path}"

        # Fixed: "xls" was missing its leading dot, so .xls files fell
        # through to the generic else branch and missed the excel guidance.
        elif file_path.suffix in [".xlsx", ".xls", ".csv"]:
            task["Question"] += (
                f" Here are the necessary table files: {file_path}, for processing excel file,"
                " you can use the excel tool or write python code to process the file"
                " step-by-step and get the information."
            )
        elif file_path.suffix in [".py"]:
            task["Question"] += f" Here are the necessary python files: {file_path}"

        else:
            task["Question"] += f" Here are the necessary files: {file_path}"

    return task
|
136 |
+
|
137 |
+
|
138 |
+
def report_results(entries):
    """Log overall and per-level accuracy tables for scored entries.

    Each entry is a dict with optional "level" and "is_correct" keys.
    Output goes through loguru's logger as tabulate-formatted grids.
    Handles an empty entries list without raising (previously this raised
    ZeroDivisionError computing the overall accuracy).
    """
    # Initialize counters
    total_entries = len(entries)
    total_correct = 0

    # Per-level counters: level -> {"total", "correct", "accuracy"}
    level_stats = {}

    # Process each entry
    for entry in entries:
        level = entry.get("level")
        is_correct = entry.get("is_correct", False)

        if level not in level_stats:
            level_stats[level] = {"total": 0, "correct": 0, "accuracy": 0}

        level_stats[level]["total"] += 1
        if is_correct:
            total_correct += 1
            level_stats[level]["correct"] += 1

    # Calculate accuracy for each level
    for level, stats in level_stats.items():
        if stats["total"] > 0:
            stats["accuracy"] = (stats["correct"] / stats["total"]) * 100

    # Print overall statistics with colorful logging
    logger.info("Overall Statistics:")
    # Guard against an empty run: report 0% instead of dividing by zero.
    overall_accuracy = (total_correct / total_entries) * 100 if total_entries else 0.0

    overall_table = [
        ["Total Entries", total_entries],
        ["Total Correct", total_correct],
        ["Overall Accuracy", f"{overall_accuracy:.2f}%"],
    ]
    logger.success(tabulate(overall_table, tablefmt="grid"))
    logger.info("")

    # Create level statistics table
    logger.info("Statistics by Level:")
    level_table = []
    headers = ["Level", "Total Entries", "Correct Answers", "Accuracy"]

    for level in sorted(level_stats.keys()):
        stats = level_stats[level]
        level_table.append(
            [level, stats["total"], stats["correct"], f"{stats['accuracy']:.2f}%"]
        )

    logger.success(tabulate(level_table, headers=headers, tablefmt="grid"))
|
191 |
+
|
192 |
+
|
193 |
+
import uuid
|
194 |
+
import time
|
195 |
+
|
196 |
+
from typing import List
|
197 |
+
|
198 |
+
import inspect
|
199 |
+
from typing import get_type_hints, Tuple
|
200 |
+
|
201 |
+
|
202 |
+
def stream_message_template(model: str, message: str):
    """Build one OpenAI-compatible streaming chat-completion chunk.

    The chunk carries *message* as the delta content of a single choice,
    with a fresh UUID-suffixed id and the current unix timestamp.
    """
    chunk_choice = {
        "index": 0,
        "delta": {"content": message},
        "logprobs": None,
        "finish_reason": None,
    }
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [chunk_choice],
    }
|
217 |
+
|
218 |
+
|
219 |
+
def get_last_user_message(messages: List[dict]) -> str:
    """Return the text of the most recent user message, or None.

    For multimodal (list-valued) content, the first "text" part is returned;
    a list with no text part is returned as-is.
    """
    for message in reversed(messages):
        if message["role"] != "user":
            continue
        content = message["content"]
        if isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    return part["text"]
        return content
    return None
|
228 |
+
|
229 |
+
|
230 |
+
def get_last_assistant_message(messages: List[dict]) -> str:
    """Return the text of the most recent assistant message, or None.

    For multimodal (list-valued) content, the first "text" part is returned;
    a list with no text part is returned as-is.
    """
    for message in reversed(messages):
        if message["role"] != "assistant":
            continue
        content = message["content"]
        if isinstance(content, list):
            for part in content:
                if part["type"] == "text":
                    return part["text"]
        return content
    return None
|
239 |
+
|
240 |
+
|
241 |
+
def get_system_message(messages: List[dict]) -> dict:
    """Return the first system-role message in *messages*, or None."""
    return next((message for message in messages if message["role"] == "system"), None)
|
246 |
+
|
247 |
+
|
248 |
+
def remove_system_message(messages: List[dict]) -> List[dict]:
    """Return a new list of *messages* with all system-role entries removed."""
    non_system = []
    for message in messages:
        if message["role"] != "system":
            non_system.append(message)
    return non_system
|
250 |
+
|
251 |
+
|
252 |
+
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
    """Split *messages* into (first system message or None, non-system rest)."""
    system_message = get_system_message(messages)
    remaining = remove_system_message(messages)
    return system_message, remaining
|
254 |
+
|
255 |
+
|
256 |
+
def add_or_update_system_message(content: str, messages: List[dict]) -> List[dict]:
    """
    Adds a new system message at the beginning of the messages list
    or updates the existing system message at the beginning.

    :param content: The message to be added or prepended.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """

    if messages and messages[0].get("role") == "system":
        # Fixed: "+=" duplicated the existing content (old + new + "\n" + old);
        # prepend the new content to the existing system message instead.
        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})

    return messages
|
273 |
+
|
274 |
+
|
275 |
+
def doc_to_dict(docstring):
    """Parse a reST-style docstring into {"description": ..., "params": {...}}.

    The description is the second line of the docstring (the first line after
    the opening quotes in a conventional multi-line docstring); ``:param
    name: desc`` lines populate the params dict.

    :param docstring: the docstring text (may be a bare fallback name).
    :return: dict with "description" (str) and "params" (dict) keys.
    """
    lines = docstring.split("\n")
    # Guard: a single-line docstring (e.g. the bare function-name fallback
    # used by get_tools_specs) has no second line; lines[1] used to raise
    # IndexError here.
    description = lines[1].strip() if len(lines) > 1 else lines[0].strip()
    param_dict = {}

    for line in lines:
        if ":param" in line:
            line = line.replace(":param", "").strip()
            param, desc = line.split(":", 1)
            param_dict[param.strip()] = desc.strip()
    ret_dict = {"description": description, "params": param_dict}
    return ret_dict
|
287 |
+
|
288 |
+
|
289 |
+
def get_tools_specs(tools) -> List[dict]:
    """Build OpenAI function-calling specs for every public callable on *tools*.

    Each non-dunder callable attribute becomes one spec: its description and
    per-parameter descriptions come from the docstring (via doc_to_dict),
    parameter types from type hints, and required parameters from the
    signature (those without defaults).
    """
    # Collect every public callable attribute of the tools object.
    function_list = [
        {"name": func, "function": getattr(tools, func)}
        for func in dir(tools)
        if callable(getattr(tools, func)) and not func.startswith("__")
    ]

    specs = []

    for function_item in function_list:
        function_name = function_item["name"]
        function = function_item["function"]

        # Fall back to the function name when there is no docstring.
        function_doc = doc_to_dict(function.__doc__ or function_name)
        specs.append(
            {
                "name": function_name,
                # TODO: multi-line desc?
                "description": function_doc.get("description", function_name),
                "parameters": {
                    "type": "object",
                    "properties": {
                        param_name: {
                            # JSON-schema type from the annotation's name;
                            # NOTE(review): typing constructs without
                            # __name__ (e.g. Optional[...]) would raise
                            # AttributeError here — confirm hints are plain
                            # classes or Literal.
                            "type": param_annotation.__name__.lower(),
                            # Literal-style annotations expose __args__,
                            # which become an enum constraint.
                            **(
                                {
                                    "enum": (
                                        param_annotation.__args__
                                        if hasattr(param_annotation, "__args__")
                                        else None
                                    )
                                }
                                if hasattr(param_annotation, "__args__")
                                else {}
                            ),
                            "description": function_doc.get("params", {}).get(
                                param_name, param_name
                            ),
                        }
                        for param_name, param_annotation in get_type_hints(
                            function
                        ).items()
                        if param_name != "return"
                    },
                    # Parameters without a default value are required.
                    "required": [
                        name
                        for name, param in inspect.signature(
                            function
                        ).parameters.items()
                        if param.default is param.empty
                    ],
                },
            }
        )

    return specs
|