Duibonduil committed on
Commit
fdfa85c
·
verified ·
1 Parent(s): ca41391

Upload tasks.py

Browse files
AWorld-main/aworlddistributed/aworldspace/tasks.py ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from typing import AsyncGenerator, Optional, List
6
+
7
+ from aworld.utils.common import get_local_ip
8
+ from fastapi import APIRouter, Query, Response
9
+ from fastapi.responses import StreamingResponse
10
+
11
+ import logging
12
+ import traceback
13
+ from asyncio import Queue
14
+ import asyncio
15
+
16
+ from aworld.models.model_response import ModelResponse
17
+ from pydantic import BaseModel, Field, PrivateAttr
18
+
19
+ from aworldspace.db.db import AworldTaskDB, SqliteTaskDB, PostgresTaskDB
20
+ from aworldspace.utils.job import generate_openai_chat_completion, call_pipeline
21
+ from aworldspace.utils.log import task_logger
22
+ from base import AworldTask, AworldTaskResult, OpenAIChatCompletionForm, OpenAIChatMessage, AworldTaskForm
23
+
24
+
25
+ from config import ROOT_DIR
26
+
27
+ __STOP_TASK__ = object()
28
+
29
+
30
+
31
+
32
class AworldTaskExecutor(BaseModel):
    """
    Task executor.

    - Loads ``INIT`` tasks from the task DB and executes them in a loop.
    - Uses a bounded semaphore to limit the number of concurrent tasks.
    """
    _task_db: AworldTaskDB = PrivateAttr()
    _tasks: Queue = PrivateAttr()
    _semaphore: asyncio.BoundedSemaphore = PrivateAttr()
    # Strong references to in-flight asyncio tasks; the event loop itself only
    # keeps weak references, so an unreferenced task may be garbage-collected.
    _running: set = PrivateAttr()
    # os.environ values are strings, so coerce to int before the semaphore sees it.
    max_concurrent: int = Field(
        default=int(os.environ.get("AWORLD_MAX_CONCURRENT_TASKS", 2)),
        description="max concurrent tasks",
    )

    def __init__(self, task_db: AworldTaskDB):
        super().__init__()
        self._task_db = task_db
        self._tasks = Queue()
        self._semaphore = asyncio.BoundedSemaphore(self.max_concurrent)
        self._running = set()

    async def start(self):
        """
        Run the executor loop until the stop sentinel is dequeued.

        Each dequeued task is run in its own asyncio task once a semaphore
        slot is available. On stop, tasks that are still running are awaited
        before the loop returns.
        """
        await asyncio.sleep(5)
        logging.info(f"🚀[task executor] start, max concurrent is {self.max_concurrent}")
        while True:
            # load task if queue is empty and semaphore is not full
            if self._tasks.empty():
                await self.load_task()
            task = await self._tasks.get()
            # The sentinel is a unique object, so compare by identity.
            if task is __STOP_TASK__:
                if self._running:
                    await asyncio.gather(*self._running, return_exceptions=True)
                logging.info("✅[task executor] stop, all tasks finished")
                break
            if not task:
                logging.info("task is none")
                continue
            # acquire semaphore
            await self._semaphore.acquire()
            runner = asyncio.create_task(self._run_task_and_release_semaphore(task))
            self._running.add(runner)
            runner.add_done_callback(self._running.discard)

    async def stop(self):
        """Ask the executor loop to exit by enqueueing the stop sentinel."""
        logging.info("🛑 task executor stop, wait for all tasks to finish")
        await self._tasks.put(__STOP_TASK__)

    async def _run_task_and_release_semaphore(self, task: AworldTask):
        """
        Execute a task and release the semaphore slot when done.
        """
        start_time = time.time()
        logging.info(f"🚀[task executor] execute task#{task.task_id} start, lock acquired")
        try:
            await self.execute_task(task)
        finally:
            # release semaphore
            self._semaphore.release()
            # execute_task swallows failures, so this line is reached for failed
            # tasks too; report completion rather than success.
            logging.info(f"✅[task executor] execute task#{task.task_id} finished, use time {time.time() - start_time:.2f}s")

    async def load_task(self):
        """
        Poll the task DB until at least one ``INIT`` task has been queued.

        Loaded tasks are marked running, persisted and put on the internal
        queue. The poll interval comes from ``AWORLD_TASK_LOAD_INTERVAL``
        (seconds, default 10).

        Returns:
            True when tasks were loaded; False when polling stopped because
            something else (e.g. the stop sentinel) was queued in the meantime.
        """
        # os.environ values are strings; asyncio.sleep needs a number.
        interval = float(os.environ.get("AWORLD_TASK_LOAD_INTERVAL", 10))
        # Iterative polling: retrying through recursion would grow the call
        # stack on every idle interval.
        while True:
            # calculate the number of tasks to load
            # (free slots; asyncio exposes no public accessor for this counter)
            need_load = self._semaphore._value
            if need_load <= 0:
                logging.info(f"🔍[task executor] runner is busy, wait {interval:g}s and retry")
            else:
                tasks = await self._task_db.query_tasks_by_status(status="INIT", nums=need_load)
                logging.info(f"🔍[task executor] load {len(tasks) if tasks else 0} tasks from db (need {need_load})")
                if tasks:
                    for task in tasks:
                        task.mark_running()
                        await self._task_db.update_task(task)
                        await self._tasks.put(task)
                    return True
                logging.info(f"🔍[task executor] no task to load, wait {interval:g}s and retry")

            await asyncio.sleep(interval)
            if not self._tasks.empty():
                # Something was queued while we slept; let start() handle it.
                return False

    async def execute_task(self, task: AworldTask):
        """
        Execute a task and persist its outcome.

        On success the task is marked successful and its result is saved; on
        any exception the task is marked failed and the error is logged. The
        exception is not re-raised.
        """
        try:
            result = await self._execute_task(task)
            task.mark_success()
            await self._task_db.update_task(task)
            await self._task_db.save_task_result(result)
            task_logger.log_task_submission(task, "execute_finished", task_result=result)
        except Exception as err:
            task.mark_failed()
            await self._task_db.update_task(task)
            traceback.print_exc()
            task_logger.log_task_submission(task, "execute_failed", details=f"err is {err}")

    @staticmethod
    def _parse_stream_chunk(chunk) -> Optional[ModelResponse]:
        """
        Parse one streamed chunk into a ModelResponse.

        Returns None for the ``[DONE]`` terminator and for blank chunks.
        """
        if isinstance(chunk, (bytes, bytearray)):
            chunk = chunk.decode("utf-8")
        payload = chunk.strip()
        # Strip only the leading marker: replacing every "data:" occurrence
        # would corrupt payloads whose content contains that text.
        if payload.startswith("data:"):
            payload = payload[len("data:"):].strip()
        if not payload or payload == "[DONE]":
            return None
        return ModelResponse.from_openai_stream_chunk(json.loads(payload))

    async def _execute_task(self, task: AworldTask):
        """
        Send the task to the chat-completion pipeline and collect its output.

        Streams the response, logs chunks that carry content, and keeps the
        last ``task_output_meta`` found in a chunk's raw response as the task
        result.

        Returns:
            An AworldTaskResult holding the task result, the markdown log file
            returned by the task logger and the replay file path.
        """
        # build params
        messages = [
            OpenAIChatMessage(role="user", content=task.agent_input)
        ]
        # call_llm_model
        form_data = OpenAIChatCompletionForm(
            model=task.agent_id,
            messages=messages,
            stream=True,
            user={
                "user_id": task.user_id,
                "session_id": task.session_id,
                "task_id": task.task_id,
                "aworld_task": task.model_dump_json()
            }
        )
        response = await generate_openai_chat_completion(form_data)
        task_result = {}
        local_ip = get_local_ip()
        task.node_id = local_ip
        md_file = ""
        if response.body_iterator and isinstance(response.body_iterator, AsyncGenerator):
            async for item_content in response.body_iterator:
                item = self._parse_stream_chunk(item_content)
                if not item:
                    continue

                if item.content:
                    md_file = task_logger.log_task_result(task, item)
                    logging.info(f"task#{task.task_id} response data chunk is: {item}"[:500])

                raw_response = item.raw_response
                if raw_response and isinstance(raw_response, dict) and raw_response.get('task_output_meta'):
                    task_result = raw_response.get('task_output_meta')

        data = {
            "task_result": task_result,
            "md_file": md_file,
            "replays_file": f"trace_data/{datetime.now().strftime('%Y%m%d')}/{local_ip}/replays/task_replay_{task.task_id}.json"
        }
        result = AworldTaskResult(task=task, server_host=local_ip, data=data)
        return result
177
+
178
+
179
class AworldTaskManager(BaseModel):
    """
    Facade over the task DB and the background task executor.

    Handles task submission, result lookup (single, batch and paged export)
    and starting/stopping the executor.
    """
    _task_db: AworldTaskDB = PrivateAttr()
    _task_executor: AworldTaskExecutor = PrivateAttr()
    # Strong reference to the background executor task; the event loop only
    # keeps weak references to tasks.
    _executor_task: Optional[asyncio.Task] = PrivateAttr(default=None)

    def __init__(self, task_db: AworldTaskDB):
        super().__init__()
        self._task_db = task_db
        self._task_executor = AworldTaskExecutor(task_db=self._task_db)

    async def start_task_executor(self):
        """Start the executor loop as a background asyncio task."""
        self._executor_task = asyncio.create_task(self._task_executor.start())

    async def stop_task_executor(self):
        """Ask the executor loop to stop via its own stop sentinel."""
        # The executor's queue is the private `_tasks`, and a `None` item is
        # skipped by its loop rather than treated as a stop request.
        await self._task_executor.stop()

    async def submit_task(self, task: AworldTask):
        """Persist a new task, log the submission and return an empty result wrapper."""
        # save to db
        await self._task_db.insert_task(task)
        # log it
        task_logger.log_task_submission(task, status="init")

        return AworldTaskResult(task=task)

    async def load_one_unfinished_task(self) -> Optional[AworldTask]:
        """Fetch one ``INIT`` task, mark it running, persist that and return it (or None)."""
        tasks = await self._task_db.query_tasks_by_status(status="INIT", nums=1)
        if not tasks:
            return None

        cur_task = tasks[0]
        cur_task.mark_running()
        await self._task_db.update_task(cur_task)
        return cur_task

    async def get_task_result(self, task_id: str) -> Optional[AworldTaskResult]:
        """
        Return the latest stored result for ``task_id``.

        Returns:
            The latest result if one exists; a result wrapper around the bare
            task if the task exists without a result; None if the task is
            unknown.
        """
        task = await self._task_db.query_task_by_id(task_id)
        if not task:
            # Unknown task id: there is nothing to wrap in a result.
            return None
        task_result = await self._task_db.query_latest_task_result_by_id(task_id)
        if task_result:
            return task_result
        return AworldTaskResult(task=task)

    async def get_batch_task_results(self, task_ids: List[str]) -> List[dict]:
        """
        Batch retrieve task results, returns dictionary format.

        Each dict contains: task (required) and task_result (may be None).
        Unknown task ids are skipped.
        """
        results = []
        for task_id in task_ids:
            task = await self._task_db.query_task_by_id(task_id)
            if not task:
                continue
            task_result = await self._task_db.query_latest_task_result_by_id(task_id)
            results.append({
                "task": task,
                "task_result": task_result  # May be None
            })
        return results

    async def query_and_download_task_results(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        task_id: Optional[str] = None,
        page_size: int = 100
    ) -> List[dict]:
        """
        Query tasks and get results, support time range and task_id filtering.

        Pages through the task DB ``page_size`` rows at a time and returns one
        JSON-friendly dict per task, with the latest result attached when one
        exists.
        """
        # Build query filter conditions once; they do not change between pages.
        filter_dict = {}
        if start_time:
            filter_dict['start_time'] = start_time
        if end_time:
            filter_dict['end_time'] = end_time
        if task_id:
            filter_dict['task_id'] = task_id

        all_results = []
        page_num = 1

        while True:
            # Page query tasks
            page_result = await self._task_db.page_query_tasks(
                filter=filter_dict,
                page_size=page_size,
                page_num=page_num
            )

            tasks = page_result['items']
            if not tasks:
                break

            for task in tasks:
                # Only query task_result (may not exist)
                task_result = await self._task_db.query_latest_task_result_by_id(task.task_id)

                # Use task information to build results
                all_results.append({
                    "task_id": task.task_id,
                    "agent_id": task.agent_id,
                    "status": task.status,
                    "created_at": task.created_at.isoformat() if task.created_at else None,
                    "updated_at": task.updated_at.isoformat() if task.updated_at else None,
                    "user_id": task.user_id,
                    "session_id": task.session_id,
                    "node_id": task.node_id,
                    "client_id": task.client_id,
                    "task_data": task.model_dump(mode='json'),
                    "has_result": task_result is not None,
                    "server_host": task_result.server_host if task_result else None,
                    "result_data": task_result.data if task_result else None,
                })

            # A short page means this was the last one.
            if len(tasks) < page_size:
                break

            page_num += 1

        return all_results
303
+
304
+
305
+ ########################################################################################
306
+ ########################### API
307
+ ########################################################################################
308
+
309
router = APIRouter()

# Connection string of the task store; defaults to a SQLite file under ROOT_DIR.
task_db_path = os.environ.get("AWORLD_TASK_DB_PATH", f"sqlite:///{ROOT_DIR}/db/aworld.db")


def _build_task_db(db_path: str):
    """Pick the task DB backend that matches the scheme of ``db_path``."""
    if db_path.startswith("sqlite://"):
        return SqliteTaskDB(db_path=db_path)
    if db_path.startswith("mysql://"):
        return None  # todo: add mysql task db
    if db_path.startswith(("postgresql://", "postgresql+")):
        return PostgresTaskDB(db_url=db_path)
    raise ValueError("❌ task_db_path is not a valid sqlite, mysql or postgresql path")


task_db = _build_task_db(task_db_path)

# Shared manager used by the route handlers in this module.
task_manager = AworldTaskManager(task_db)
323
+
324
@router.post("/submit_task")
async def submit_task(form_data: AworldTaskForm) -> Optional[AworldTaskResult]:
    """
    Submit a task for asynchronous execution.

    Raises:
        ValueError: if the form carries no task, or if persisting it fails
            (the underlying error is logged and chained as the cause).
    """
    # Validate before touching form_data.task: logging the task id first would
    # raise AttributeError on an empty task instead of the intended error.
    if not form_data.task:
        raise ValueError("task is empty")

    task_id = form_data.task.task_id
    logging.info(f"🚀 submit task#{task_id} start")

    try:
        task_result = await task_manager.submit_task(form_data.task)
        logging.info(f"✅ submit task#{task_id} success")
        return task_result
    except Exception as err:
        traceback.print_exc()
        logging.error(f"❌ submit task#{task_id} failed, err is {err}")
        raise ValueError("❌ submit task failed, please see logs for details") from err
339
+
340
+
341
@router.get("/task_result")
async def get_task_result(task_id) -> Optional[AworldTaskResult]:
    """
    Look up the result of a single task by id.

    Raises:
        ValueError: if ``task_id`` is empty or the lookup fails (details are
            written to the log).
    """
    if not task_id:
        raise ValueError("❌ task_id is empty")

    logging.info(f"🚀 get task result#{task_id} start")
    try:
        found = await task_manager.get_task_result(task_id)
    except Exception as err:
        traceback.print_exc()
        logging.error(f"❌ get task result#{task_id} failed, err is {err}")
        raise ValueError("❌ get task result failed, please see logs for details")
    else:
        try:
            logging.info(f"✅ get task result#{task_id} success, task result is {found}")
        except Exception as err:
            # Formatting the result for the log line is covered by the same
            # handling as the lookup itself.
            traceback.print_exc()
            logging.error(f"❌ get task result#{task_id} failed, err is {err}")
            raise ValueError("❌ get task result failed, please see logs for details")
        return found
355
+
356
@router.post("/get_batch_task_results")
async def get_batch_task_results(task_ids: List[str]) -> List[dict]:
    """
    Look up several tasks at once.

    Returns one dict per task found by the task manager.

    Raises:
        ValueError: if ``task_ids`` is empty or the lookup fails (details are
            written to the log).
    """
    # An empty list is falsy, so one truthiness test covers None and [].
    if not task_ids:
        raise ValueError("❌ task_ids is empty")

    logging.info(f"🚀 get batch task results start, task_ids: {task_ids}")
    try:
        found = await task_manager.get_batch_task_results(task_ids)
        found_count = len(found)
        logging.info(f"✅ get batch task results success, found {found_count} results")
        return found
    except Exception as err:
        traceback.print_exc()
        logging.error(f"❌ get batch task results failed, err is {err}")
        raise ValueError("❌ get batch task results failed, please see logs for details")
370
+
371
def _parse_query_time(value: Optional[str], field_name: str) -> Optional[datetime]:
    """Parse a ``YYYY-MM-DD HH:MM:SS`` query value; an empty value means no bound."""
    if not value:
        return None
    try:
        return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
    except ValueError:
        raise ValueError(f"❌ {field_name}格式错误,请使用 YYYY-MM-DD HH:MM:SS 格式")


@router.get("/download_task_results")
async def download_task_results(
    start_time: Optional[str] = Query(None, description="Start time, format: YYYY-MM-DD HH:MM:SS"),
    end_time: Optional[str] = Query(None, description="End time, format: YYYY-MM-DD HH:MM:SS"),
    task_id: Optional[str] = Query(None, description="Task ID"),
    # The 1..1000 bounds are declared as validation constraints so FastAPI
    # enforces them; inside the description string they had no effect.
    page_size: int = Query(100, ge=1, le=1000, description="Page size, ge=1, le=1000")
) -> StreamingResponse:
    """
    Download task results, generate jsonl format file.

    Query parameters support: time range (based on creation time), task_id.
    ``page_size`` only controls how many tasks are fetched per DB page; every
    matching task is included in the download.

    Raises:
        ValueError: if a time value is malformed or the query fails (the
            underlying error is logged and chained as the cause).
    """
    logging.info(f"🚀 download task results start, start_time: {start_time}, end_time: {end_time}, task_id: {task_id}")

    try:
        start_datetime = _parse_query_time(start_time, "start_time")
        end_datetime = _parse_query_time(end_time, "end_time")

        results = await task_manager.query_and_download_task_results(
            start_time=start_datetime,
            end_time=end_datetime,
            task_id=task_id,
            page_size=page_size
        )

        if not results:
            logging.info("📄 no task results found")

            def generate_empty():
                yield ""

            return StreamingResponse(
                generate_empty(),
                media_type="application/jsonl",
                headers={"Content-Disposition": "attachment; filename=task_results_empty.jsonl"}
            )

        # Generate jsonl content: one JSON document per line.
        def generate_jsonl():
            for result in results:
                yield json.dumps(result, ensure_ascii=False) + "\n"

        # Generate file name
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"task_results_{timestamp}.jsonl"

        logging.info(f"✅ download task results success, total: {len(results)} results")

        return StreamingResponse(
            generate_jsonl(),
            media_type="application/jsonl",
            headers={"Content-Disposition": f"attachment; filename={filename}"}
        )

    except Exception as err:
        traceback.print_exc()
        logging.error(f"❌ download task results failed, err is {err}")
        raise ValueError(f"❌ download task results failed: {str(err)}") from err