Upload aworld_client.py
AWorld-main/aworlddistributed/aworld_client.py
ADDED
@@ -0,0 +1,369 @@
import asyncio
import logging
import os
import traceback
from datetime import datetime
from typing import AsyncGenerator

from aworld.models.llm import acall_llm_model, get_llm_model, acall_llm_model_stream
from aworld.models.model_response import ModelResponse, LLMResponseError
from pydantic import BaseModel, Field

from base import AworldTask, AworldTaskResult, AworldTaskForm


class TaskLogger:
    """Task submission logger."""

    def __init__(self, log_file: str = "aworld_task_submissions.log"):
        self.log_file = 'task_logs/' + log_file
        self._ensure_log_file_exists()

    def _ensure_log_file_exists(self):
        """Ensure the log file and its directory exist."""
        if not os.path.exists(self.log_file):
            os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
            with open(self.log_file, 'w', encoding='utf-8') as f:
                f.write("# Aworld Task Submission Log\n")
                f.write("# Format: [timestamp] task_id | agent_id | server | status | agent_answer | correct_answer | is_correct | details\n\n")

    def log_task_submission(self, task: AworldTask, server: str, status: str, details: str = "", task_result: AworldTaskResult = None):
        """Log a task submission."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        agent_answer = task_result.data.get('agent_answer') if task_result and task_result.data else None
        correct_answer = task_result.data.get('correct_answer') if task_result and task_result.data else None
        is_correct = task_result.data.get('gaia_correct') if task_result and task_result.data else None
        log_entry = (
            f"[{timestamp}] {task.task_id} | {task.agent_id} | {task.node_id} | {status} | "
            f"{agent_answer} | {correct_answer} | {is_correct} | {details}\n"
        )

        try:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(log_entry)
        except Exception as e:
            logging.error(f"Failed to write task submission log: {e}")

    def log_task_result(self, task: AworldTask, result: ModelResponse):
        """Log a task result to a markdown file."""
        try:
            # create result directory
            date_str = datetime.now().strftime("%Y%m%d")
            result_dir = f"task_logs/result/{date_str}"
            os.makedirs(result_dir, exist_ok=True)

            # create markdown file
            md_file = f"{result_dir}/{task.task_id}.md"

            # concat content
            content_parts = []
            if hasattr(result, 'content') and result.content:
                if isinstance(result.content, list):
                    content_parts.extend(result.content)
                else:
                    content_parts.append(str(result.content))

            # write to markdown file
            file_exists = os.path.exists(md_file)
            with open(md_file, 'a', encoding='utf-8') as f:
                # only write the title block when the file does not exist yet
                if not file_exists:
                    f.write(f"# Task Result: {task.task_id}\n\n")
                    f.write(f"**Agent ID:** {task.agent_id}\n\n")
                    f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
                    f.write("## Content\n\n")

                # write content parts
                if content_parts:
                    for content in content_parts:
                        f.write(f"{content}\n\n")
                else:
                    f.write("No content available.\n\n")

        except Exception as e:
            logging.error(f"Failed to write task result log: {e}")


task_logger = TaskLogger(log_file=f"aworld_task_submissions_{datetime.now().strftime('%Y%m%d')}.log")


class AworldTaskClient(BaseModel):
    """Client for submitting AworldTask jobs to aworld servers and retrieving their results."""

    know_hosts: list[str] = Field(default_factory=list, description="aworld server list")
    tasks: list[AworldTask] = Field(default_factory=list, description="submitted task list")
    task_states: dict[str, AworldTaskResult] = Field(default_factory=dict, description="task states")

    async def submit_task(self, task: AworldTask, background: bool = True):
        if not self.know_hosts:
            raise ValueError("No aworld server hosts configured.")
        # 1. select an aworld server from know_hosts using round-robin
        if not hasattr(self, '_current_server_index'):
            self._current_server_index = 0
        aworld_server = self.know_hosts[self._current_server_index]
        if not aworld_server.startswith("http"):
            aworld_server = "http://" + aworld_server
        self._current_server_index = (self._current_server_index + 1) % len(self.know_hosts)

        # 2. call _submit_task
        result = await self._submit_task(aworld_server, task, background)
        # 3. update task_states
        self.task_states[task.task_id] = result

    async def _submit_task(self, aworld_server, task: AworldTask, background: bool = True):
        try:
            logging.info(f"submit task#{task.task_id} to cluster#[{aworld_server}]")
            if not background:
                task_result = await self._submit_task_to_server(aworld_server, task)
            else:
                task_result = await self._async_submit_task_to_server(aworld_server, task)
            return task_result
        except Exception as e:
            if isinstance(e, LLMResponseError):
                if e.message == 'peer closed connection without sending complete message body (incomplete chunked read)':
                    task_logger.log_task_submission(task, aworld_server, "server_close_connection", str(e))
                    logging.error(f"execute task to {task.node_id} server_close_connection: [{e}], please check the replays in a moment")
                    return
            traceback.print_exc()
            logging.error(f"execute task to {task.node_id} execute_failed: [{e}], please check the server logs")
            task_logger.log_task_submission(task, aworld_server, "execute_failed", str(e))

    async def _async_submit_task_to_server(self, aworld_server, task: AworldTask):
        import httpx
        # build the AworldTaskForm payload
        form_data = AworldTaskForm(task=task)
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{aworld_server}/api/v1/tasks/submit_task", json=form_data.model_dump())
            resp.raise_for_status()
            data = resp.json()
            task_logger.log_task_submission(task, aworld_server, "submitted")
            return AworldTaskResult(**data)

    async def _submit_task_to_server(self, aworld_server, task: AworldTask):
        # build params
        llm_model = get_llm_model(
            llm_provider="openai",
            model_name=task.agent_id,
            base_url=f"{aworld_server}/v1",
            api_key="0p3n-w3bu!"
        )
        messages = [
            {"role": "user", "content": task.agent_input}
        ]
        # call llm model
        data = acall_llm_model_stream(llm_model, messages, stream=True, user={
            "user_id": task.user_id,
            "session_id": task.session_id,
            "task_id": task.task_id,
            "aworld_task": task.model_dump_json()
        })
        items = []
        task_result = {}
        if isinstance(data, AsyncGenerator):
            async for item in data:
                items.append(item)
                if item.raw_response and item.raw_response.model_extra and item.raw_response.model_extra.get('node_id'):
                    if not task.node_id:
                        logging.info(f"submit task#{task.task_id} success. execute pod ip is [{item.raw_response.model_extra.get('node_id')}]")
                        task.node_id = item.raw_response.model_extra.get('node_id')
                        task_logger.log_task_submission(task, aworld_server, "submitted")

                if item.content:
                    task_logger.log_task_result(task, item)
                logging.info(f"task#{task.task_id} response data chunk is: {item}"[:500])

                if item.raw_response and item.raw_response.model_extra and item.raw_response.model_extra.get(
                        'task_output_meta'):
                    task_result = item.raw_response.model_extra.get('task_output_meta')

        elif isinstance(data, ModelResponse):
            if data.raw_response and data.raw_response.model_extra and data.raw_response.model_extra.get('node_id'):
                if not task.node_id:
                    logging.info(f"submit task#{task.task_id} success. execute pod ip is [{data.raw_response.model_extra.get('node_id')}]")
                    task.node_id = data.raw_response.model_extra.get('node_id')

            logging.info(f"task#{task.task_id} response data is: {data}")
            task_logger.log_task_result(task, data)
            if data.raw_response and data.raw_response.model_extra and data.raw_response.model_extra.get('task_output_meta'):
                task_result = data.raw_response.model_extra.get('task_output_meta')

        result = AworldTaskResult(task=task, server_host=aworld_server, data=task_result)
        task_logger.log_task_submission(task, aworld_server, "execute_finished", task_result=result)
        return result

    async def get_task_state(self, task_id: str):
        if not isinstance(self.task_states, dict):
            self.task_states = dict(self.task_states)
        return self.task_states.get(task_id, None)

    async def download_task_results(
        self,
        start_time: str = None,
        end_time: str = None,
        task_id: str = None,
        page_size: int = 100,
        save_path: str = None
    ) -> str:
        """
        Download task results and save them as a JSONL file.

        Args:
            start_time: Start time, format: YYYY-MM-DD HH:MM:SS
            end_time: End time, format: YYYY-MM-DD HH:MM:SS
            task_id: Task ID
            page_size: Page size
            save_path: Save path; if not specified, it is generated automatically

        Returns:
            str: Save path
        """
        if not self.know_hosts:
            raise ValueError("No aworld server hosts configured.")

        # select server
        if not hasattr(self, '_current_server_index'):
            self._current_server_index = 0
        aworld_server = self.know_hosts[self._current_server_index]

        logging.info(f"🚀 downloading task results from server: {aworld_server}")

        try:
            import httpx

            # build query params
            params = {"page_size": page_size}
            if start_time:
                params["start_time"] = start_time
            if end_time:
                params["end_time"] = end_time
            if task_id:
                params["task_id"] = task_id

            # send download request
            async with httpx.AsyncClient(timeout=300.0) as client:  # 5-minute timeout
                response = await client.get(
                    f"{aworld_server}/api/v1/tasks/download_task_results",
                    params=params
                )
                response.raise_for_status()

                # if no save path is specified, generate one automatically
                if not save_path:
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    save_path = f"task_results_{timestamp}.jsonl"

                # ensure directory exists
                save_dir = os.path.dirname(save_path) if os.path.dirname(save_path) else "."
                os.makedirs(save_dir, exist_ok=True)

                with open(save_path, 'wb') as f:
                    for chunk in response.iter_bytes():
                        f.write(chunk)

                # calculate file size
                file_size = os.path.getsize(save_path)
                logging.info(f"✅ task results downloaded successfully, file: {save_path}, size: {file_size} bytes")

                return save_path

        except Exception as e:
            logging.error(f"❌ download task results failed: {e}")
            raise ValueError(f"❌ download task results failed: {str(e)}")

    async def download_task_results_to_memory(
        self,
        start_time: str = None,
        end_time: str = None,
        task_id: str = None,
        page_size: int = 100
    ) -> list:
        """
        Download task results into memory and return the parsed data list.

        Args:
            start_time: Start time, format: YYYY-MM-DD HH:MM:SS
            end_time: End time, format: YYYY-MM-DD HH:MM:SS
            task_id: Task ID
            page_size: Page size

        Returns:
            list: Task results data list
        """
        if not self.know_hosts:
            raise ValueError("No aworld server hosts configured.")

        # select server
        if not hasattr(self, '_current_server_index'):
            self._current_server_index = 0
        aworld_server = self.know_hosts[self._current_server_index]

        logging.info(f"🚀 downloading task results to memory from server: {aworld_server}")

        try:
            import httpx
            import json

            # build query params
            params = {"page_size": page_size}
            if start_time:
                params["start_time"] = start_time
            if end_time:
                params["end_time"] = end_time
            if task_id:
                params["task_id"] = task_id

            # send download request
            async with httpx.AsyncClient(timeout=300.0) as client:  # 5-minute timeout
                response = await client.get(
                    f"{aworld_server}/api/v1/tasks/download_task_results",
                    params=params
                )
                response.raise_for_status()

                # parse jsonl content
                results = []
                content = response.text
                if content.strip():  # check content is not empty
                    for line in content.strip().split('\n'):
                        if line.strip():  # skip empty lines
                            try:
                                result_data = json.loads(line)
                                results.append(result_data)
                            except json.JSONDecodeError as e:
                                logging.warning(f"Failed to parse line: {line}, error: {e}")

                logging.info(f"✅ task results downloaded to memory successfully, total: {len(results)} records")

                return results

        except Exception as e:
            logging.error(f"❌ download task results to memory failed: {e}")
            raise ValueError(f"❌ download task results to memory failed: {str(e)}")

    def parse_task_results_file(self, file_path: str) -> list:
        """
        Parse a local task results JSONL file.

        Args:
            file_path: jsonl file path

        Returns:
            list: Parsed task results list
        """
        import json

        results = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:  # skip empty lines
                        try:
                            result_data = json.loads(line)
                            results.append(result_data)
                        except json.JSONDecodeError as e:
                            logging.warning(f"Failed to parse line {line_num} in {file_path}: {e}")

            logging.info(f"✅ parsed {len(results)} task results from {file_path}")
            return results

        except Exception as e:
            logging.error(f"❌ failed to parse task results file {file_path}: {e}")
            raise ValueError(f"❌ failed to parse task results file: {str(e)}")
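A minimal usage sketch for the client above, assuming the file is importable as aworld_client alongside base.py. The host address and the AworldTask field values (task_id, agent_id, agent_input, user_id, session_id) are illustrative placeholders inferred from how this file reads the task object; check base.py for the real schema before relying on them.

# usage_sketch.py - hypothetical driver for AworldTaskClient (not part of the upload).
# Host and AworldTask field values below are assumptions for illustration only.
import asyncio

from aworld_client import AworldTaskClient
from base import AworldTask


async def main():
    client = AworldTaskClient(know_hosts=["127.0.0.1:8000"])  # assumed host:port

    task = AworldTask(
        task_id="demo-task-001",
        agent_id="gaia_agent",          # assumed agent/model name
        agent_input="What is 2 + 2?",
        user_id="demo-user",
        session_id="demo-session",
    )

    # background=True posts the task via /api/v1/tasks/submit_task;
    # background=False streams the run through the OpenAI-compatible /v1 endpoint.
    await client.submit_task(task, background=True)

    state = await client.get_task_state(task.task_id)
    print("task state:", state)

    # pull finished results back as parsed JSONL records
    results = await client.download_task_results_to_memory(task_id=task.task_id)
    print(f"downloaded {len(results)} result records")


if __name__ == "__main__":
    asyncio.run(main())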