Duibonduil committed on
Commit
cb99dbd
·
verified ·
1 Parent(s): d3b8afc

Upload aworld_client.py

Browse files
AWorld-main/aworlddistributed/aworld_client.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import traceback
5
+ from datetime import datetime
6
+ from typing import AsyncGenerator
7
+
8
+ from aworld.models.llm import acall_llm_model, get_llm_model, acall_llm_model_stream
9
+ from aworld.models.model_response import ModelResponse, LLMResponseError
10
+ from pydantic import BaseModel, Field
11
+
12
+ from base import AworldTask, AworldTaskResult, AworldTaskForm
13
+
14
+
15
class TaskLogger:
    """Append-only file logger for task submissions and task results.

    Submission events are written as pipe-delimited lines to a single log
    file under ``task_logs/``. Task output is appended to one markdown file
    per task under ``task_logs/result/<YYYYMMDD>/``.
    """

    def __init__(self, log_file: str = "aworld_task_submissions.log"):
        """Set up the submission log, creating it when it is missing.

        Args:
            log_file: File name of the submission log, relative to
                ``task_logs/``.
        """
        self.log_file = 'task_logs/' + log_file
        self._ensure_log_file_exists()

    def _ensure_log_file_exists(self):
        """Create the log file (and its directory) with a header if absent."""
        if os.path.exists(self.log_file):
            return
        os.makedirs(os.path.dirname(self.log_file), exist_ok=True)
        with open(self.log_file, 'w', encoding='utf-8') as f:
            f.write("# Aworld Task Submission Log\n")
            f.write("# Format: [timestamp] task_id | agent_id | server | status | agent_answer | correct_answer | is_correct | details\n\n")

    def log_task_submission(self, task: AworldTask, server: str, status: str, details: str = "", task_result: AworldTaskResult = None):
        """Append one submission event to the log file.

        Write failures are logged and swallowed so that logging never breaks
        task submission.

        Args:
            task: Task being reported; ``task_id``, ``agent_id`` and
                ``node_id`` are read from it.
            server: Server the task was sent to. Used for the "server"
                column when the task has no ``node_id`` yet.
            status: Short status label for the event.
            details: Free-form text appended at the end of the line.
            task_result: Optional result whose ``data`` mapping supplies the
                ``agent_answer``, ``correct_answer`` and ``gaia_correct``
                columns; they are written as ``None`` when unavailable.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Resolve the answer columns once instead of repeating the guard per column.
        data = task_result.data if task_result and task_result.data else {}
        # Prefer the executing node; fall back to the submission target so the
        # "server" column is never a bare "None" before a node id is known.
        node = task.node_id or server
        log_entry = (
            f"[{timestamp}] {task.task_id} | {task.agent_id} | {node} | {status} | "
            f"{data.get('agent_answer')} | {data.get('correct_answer')} | "
            f"{data.get('gaia_correct')} |{details}\n"
        )

        try:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(log_entry)
        except Exception as e:
            logging.error(f"Failed to write task submission log: {e}")

    def log_task_result(self, task: AworldTask, result: ModelResponse):
        """Append a response's content to the task's markdown result file.

        The title block is written only when the file is first created, so
        repeated calls (for example one per streamed chunk) accumulate content
        under a single header. Failures are logged and swallowed.

        Args:
            task: Task the result belongs to; ``task_id`` names the file.
            result: Response whose ``content`` (a list of parts or a single
                value) is appended.
        """
        try:
            # One directory per day keeps result files grouped by run date.
            date_str = datetime.now().strftime("%Y%m%d")
            result_dir = f"task_logs/result/{date_str}"
            os.makedirs(result_dir, exist_ok=True)

            md_file = f"{result_dir}/{task.task_id}.md"

            # Normalise content into a list of parts.
            content = getattr(result, 'content', None)
            if not content:
                content_parts = []
            elif isinstance(content, list):
                content_parts = list(content)
            else:
                content_parts = [str(content)]

            # Must be checked before opening in append mode, which creates the file.
            file_exists = os.path.exists(md_file)
            with open(md_file, 'a', encoding='utf-8') as f:
                if not file_exists:
                    f.write(f"# Task Result: {task.task_id}\n\n")
                    f.write(f"**Agent ID:** {task.agent_id}\n\n")
                    f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
                    f.write("## Content\n\n")

                if content_parts:
                    f.writelines(f"{part}\n\n" for part in content_parts)
                else:
                    f.write("No content available.\n\n")

        except Exception as e:
            logging.error(f"Failed to write task result log: {e}")
79
+
80
# Shared logger instance; the submission log file name embeds the date on
# which this module was imported.
task_logger = TaskLogger(
    log_file="aworld_task_submissions_{}.log".format(datetime.now().strftime('%Y%m%d')),
)
81
+
82
class AworldTaskClient(BaseModel):
    """Client that submits tasks to AWorld servers and fetches their results.

    Servers come from ``know_hosts``. Submissions rotate through the hosts
    round-robin; downloads read from the host the rotation currently points
    at without advancing it.
    """
    know_hosts: list[str] = Field(default_factory=list, description="aworldserver list")
    tasks: list[AworldTask] = Field(default_factory=list, description="submitted task list")
    task_states: dict[str, AworldTaskResult] = Field(default_factory=dict, description="task_states")

    def _select_server(self, advance: bool = False) -> str:
        """Return the base URL of the host the round-robin cursor points at.

        Args:
            advance: Move the cursor to the next host after reading it.

        Returns:
            The selected host, with ``http://`` prepended when it does not
            already start with ``http``.

        Raises:
            ValueError: If ``know_hosts`` is empty.
        """
        if not self.know_hosts:
            raise ValueError("No aworld server hosts configured.")
        # The cursor is created lazily on first use. The modulo keeps it valid
        # even if know_hosts has shrunk since the previous call.
        index = getattr(self, '_current_server_index', 0) % len(self.know_hosts)
        server = self.know_hosts[index]
        if not server.startswith("http"):
            server = "http://" + server
        if advance:
            index = (index + 1) % len(self.know_hosts)
        self._current_server_index = index
        return server

    @staticmethod
    def _response_extra(response) -> dict:
        """Return ``response.raw_response.model_extra``, or ``{}`` when missing."""
        raw = response.raw_response
        extra = raw.model_extra if raw else None
        return extra or {}

    @staticmethod
    def _record_node_id(task: AworldTask, extra: dict) -> bool:
        """Copy ``node_id`` from ``extra`` onto the task if the task has none.

        Returns:
            True when the task's ``node_id`` was set by this call.
        """
        node_id = extra.get('node_id')
        if not node_id or task.node_id:
            return False
        logging.info(f"submit task#{task.task_id} success. execute pod ip is [{node_id}]")
        task.node_id = node_id
        return True

    @staticmethod
    def _build_download_params(start_time, end_time, task_id, page_size) -> dict:
        """Build the download query string, omitting filters that are unset."""
        params = {"page_size": page_size}
        optional = {"start_time": start_time, "end_time": end_time, "task_id": task_id}
        params.update({key: value for key, value in optional.items() if value})
        return params

    async def submit_task(self, task: AworldTask, background: bool = True):
        """Submit a task to the next server in the round-robin rotation.

        Args:
            task: Task to submit.
            background: When True the task is posted to the submit endpoint
                and this call returns once the server answers; when False the
                task is executed through the streaming chat endpoint and this
                call waits for the stream to finish.

        Returns:
            The ``AworldTaskResult`` for the task, or ``None`` when the
            submission failed (failures are logged, not raised).

        Raises:
            ValueError: If ``know_hosts`` is empty.
        """
        aworld_server = self._select_server(advance=True)
        result = await self._submit_task(aworld_server, task, background)
        # Only record real results: task_states is typed as holding
        # AworldTaskResult values, and a failed submission yields None.
        if result is not None:
            self.task_states[task.task_id] = result
        return result

    async def _submit_task(self, aworld_server, task: AworldTask, background: bool = True):
        """Dispatch the task and convert any failure into a logged ``None``.

        Args:
            aworld_server: Base URL of the target server.
            task: Task to submit.
            background: Selects the submit endpoint (True) or the streaming
                execution path (False).

        Returns:
            The ``AworldTaskResult`` on success, otherwise ``None``.
        """
        peer_closed = 'peer closed connection without sending complete message body (incomplete chunked read)'
        try:
            logging.info(f"submit task#{task.task_id} to cluster#[{aworld_server}]")
            if background:
                return await self._async_submit_task_to_server(aworld_server, task)
            return await self._submit_task_to_server(aworld_server, task)
        except Exception as e:
            # A connection dropped mid-stream is reported separately from other failures.
            if isinstance(e, LLMResponseError) and getattr(e, 'message', None) == peer_closed:
                task_logger.log_task_submission(task, aworld_server, "server_close_connection", str(e))
                logging.error(f"execute task to {task.node_id} server_close_connection: [{e}], please see replays wait a moment")
                return None
            traceback.print_exc()
            logging.error(f"execute task to {task.node_id} execute_failed: [{e}], please see logs from server ")
            task_logger.log_task_submission(task, aworld_server, "execute_failed", str(e))
            return None

    async def _async_submit_task_to_server(self, aworld_server, task: AworldTask):
        """POST the task to the server's submit endpoint.

        Args:
            aworld_server: Base URL of the target server.
            task: Task to submit.

        Returns:
            An ``AworldTaskResult`` built from the server's JSON response.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        import httpx
        # Build the AworldTaskForm request body.
        form_data = AworldTaskForm(task=task)
        async with httpx.AsyncClient() as client:
            resp = await client.post(f"{aworld_server}/api/v1/tasks/submit_task", json=form_data.model_dump())
            resp.raise_for_status()
            data = resp.json()
            task_logger.log_task_submission(task, aworld_server, "submitted")
            return AworldTaskResult(**data)

    async def _submit_task_to_server(self, aworld_server, task: AworldTask):
        """Run the task through the server's OpenAI-style endpoint and wait for it.

        Each response is appended to the task's markdown result file. The
        ``node_id`` and ``task_output_meta`` values found in the responses'
        ``raw_response.model_extra`` are copied onto the task and into the
        returned result respectively.

        Args:
            aworld_server: Base URL of the target server.
            task: Task to execute.

        Returns:
            An ``AworldTaskResult`` whose ``data`` is the last non-empty
            ``task_output_meta`` seen, or ``{}`` if none was received.
        """
        # NOTE(review): hard-coded API key — presumably the server's default
        # key; confirm and consider moving it to configuration.
        llm_model = get_llm_model(
            llm_provider="openai",
            model_name=task.agent_id,
            base_url=f"{aworld_server}/v1",
            api_key="0p3n-w3bu!"
        )
        messages = [
            {"role": "user", "content": task.agent_input}
        ]
        data = acall_llm_model_stream(llm_model, messages, stream=True, user={
            "user_id": task.user_id,
            "session_id": task.session_id,
            "task_id": task.task_id,
            "aworld_task": task.model_dump_json()
        })
        task_result = {}
        if isinstance(data, AsyncGenerator):
            async for item in data:
                extra = self._response_extra(item)
                # Log "submitted" once, when the executing node first becomes known.
                if self._record_node_id(task, extra):
                    task_logger.log_task_submission(task, aworld_server, "submitted")

                if item.content:
                    task_logger.log_task_result(task, item)
                    logging.info(f"task#{task.task_id} response data chunk is: {item}"[:500])

                if extra.get('task_output_meta'):
                    task_result = extra.get('task_output_meta')

        elif isinstance(data, ModelResponse):
            extra = self._response_extra(data)
            self._record_node_id(task, extra)

            logging.info(f"task#{task.task_id} response data is: {data}")
            task_logger.log_task_result(task, data)
            if extra.get('task_output_meta'):
                task_result = extra.get('task_output_meta')

        result = AworldTaskResult(task=task, server_host=aworld_server, data=task_result)
        task_logger.log_task_submission(task, aworld_server, "execute_finished", task_result=result)
        return result

    async def get_task_state(self, task_id: str):
        """Return the stored result for ``task_id``, or ``None`` if unknown."""
        if not isinstance(self.task_states, dict):
            self.task_states = dict(self.task_states)
        return self.task_states.get(task_id, None)

    async def download_task_results(
        self,
        start_time: str = None,
        end_time: str = None,
        task_id: str = None,
        page_size: int = 100,
        save_path: str = None
    ) -> str:
        """
        Download task results and save them as a JSONL file.

        Args:
            start_time: Start time, format: YYYY-MM-DD HH:MM:SS
            end_time: End time, format: YYYY-MM-DD HH:MM:SS
            task_id: Task ID
            page_size: Page size
            save_path: Save path; generated from the current time when omitted

        Returns:
            str: Path of the written file

        Raises:
            ValueError: If no hosts are configured, or if the download or the
                file write fails (the original error is chained as the cause).
        """
        aworld_server = self._select_server()

        logging.info(f"🚀 downloading task results from server: {aworld_server}")

        try:
            import httpx

            params = self._build_download_params(start_time, end_time, task_id, page_size)

            async with httpx.AsyncClient(timeout=300.0) as client:  # 5-minute timeout
                response = await client.get(
                    f"{aworld_server}/api/v1/tasks/download_task_results",
                    params=params
                )
                response.raise_for_status()

            if not save_path:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                save_path = f"task_results_{timestamp}.jsonl"

            # A bare file name has no directory part; write to the current directory.
            save_dir = os.path.dirname(save_path) or "."
            os.makedirs(save_dir, exist_ok=True)

            with open(save_path, 'wb') as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)

            file_size = os.path.getsize(save_path)
            logging.info(f"✅ task results downloaded successfully, file: {save_path}, size: {file_size} bytes")

            return save_path

        except Exception as e:
            logging.error(f"❌ download task results failed: {e}")
            raise ValueError(f"❌ download task results failed: {str(e)}") from e

    async def download_task_results_to_memory(
        self,
        start_time: str = None,
        end_time: str = None,
        task_id: str = None,
        page_size: int = 100
    ) -> list:
        """
        Download task results and return them as a list of parsed records.

        Lines that are not valid JSON are skipped with a warning.

        Args:
            start_time: Start time, format: YYYY-MM-DD HH:MM:SS
            end_time: End time, format: YYYY-MM-DD HH:MM:SS
            task_id: Task ID
            page_size: Page size

        Returns:
            list: One parsed JSON value per non-empty response line

        Raises:
            ValueError: If no hosts are configured, or if the download fails
                (the original error is chained as the cause).
        """
        aworld_server = self._select_server()

        logging.info(f"🚀 downloading task results to memory from server: {aworld_server}")

        try:
            import httpx
            import json

            params = self._build_download_params(start_time, end_time, task_id, page_size)

            async with httpx.AsyncClient(timeout=300.0) as client:  # 5-minute timeout
                response = await client.get(
                    f"{aworld_server}/api/v1/tasks/download_task_results",
                    params=params
                )
                response.raise_for_status()

            # Parse the JSONL body. Split on "\n" only: str.splitlines() would
            # also break on separators that may appear raw inside JSON strings.
            results = []
            content = response.text
            if content.strip():
                for line in content.strip().split('\n'):
                    if not line.strip():
                        continue
                    try:
                        results.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"Failed to parse line: {line}, error: {e}")

            logging.info(f"✅ task results downloaded to memory successfully, total: {len(results)} records")

            return results

        except Exception as e:
            logging.error(f"❌ download task results to memory failed: {e}")
            raise ValueError(f"❌ download task results to memory failed: {str(e)}") from e

    def parse_task_results_file(self, file_path: str) -> list:
        """
        Parse a local task-results JSONL file.

        Lines that are not valid JSON are skipped with a warning.

        Args:
            file_path: jsonl file path

        Returns:
            list: One parsed JSON value per non-empty line

        Raises:
            ValueError: If the file cannot be opened or read (the original
                error is chained as the cause).
        """
        import json

        results = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:  # skip blank lines
                        continue
                    try:
                        results.append(json.loads(line))
                    except json.JSONDecodeError as e:
                        logging.warning(f"Failed to parse line {line_num} in {file_path}: {e}")

            logging.info(f"✅ parsed {len(results)} task results from {file_path}")
            return results

        except Exception as e:
            logging.error(f"❌ failed to parse task results file {file_path}: {e}")
            raise ValueError(f"❌ failed to parse task results file: {str(e)}") from e
369
+