Duibonduil commited on
Commit
d79f338
·
verified ·
1 Parent(s): 2814685

Upload 6 files

Browse files
aworld/runners/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
aworld/runners/call_driven_runner.py ADDED
@@ -0,0 +1,810 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import json
4
+ import time
5
+ import traceback
6
+
7
+ import aworld.trace as trace
8
+
9
+ from typing import List, Dict, Any, Tuple
10
+
11
+ from aworld.config.conf import ToolConfig
12
+ from aworld.core.agent.base import is_agent
13
+ from aworld.agents.llm_agent import Agent
14
+ from aworld.core.common import Observation, ActionModel, ActionResult
15
+ from aworld.core.context.base import Context
16
+ from aworld.core.event.base import Message
17
+ from aworld.core.tool.base import ToolFactory, Tool, AsyncTool
18
+ from aworld.core.tool.tool_desc import is_tool_by_name
19
+ from aworld.core.task import Task, TaskResponse
20
+ from aworld.logs.util import logger, color_log, Color, trace_logger
21
+ from aworld.models.model_response import ToolCall
22
+ from aworld.output.base import StepOutput, ToolResultOutput
23
+ from aworld.runners.task_runner import TaskRunner
24
+ from aworld.runners.utils import endless_detect
25
+ from aworld.sandbox import Sandbox
26
+ from aworld.tools.utils import build_observation
27
+ from aworld.utils.common import override_in_subclass
28
+ from aworld.utils.json_encoder import NumpyEncoder
29
+
30
+
31
+ def action_result_transform(message: Message, sandbox: Sandbox) -> Tuple[Observation, float, bool, bool, dict]:
32
+ action_results = message.payload
33
+ result: ActionResult = action_results[-1]
34
+ # ignore image, dom_tree attribute, need to process them from action_results in the agent.
35
+ return build_observation(container_id=sandbox.sandbox_id,
36
+ observer=result.tool_name,
37
+ ability=result.action_name,
38
+ content=result.content,
39
+ action_result=action_results), 1.0, result.is_done, result.is_done, {}
40
+
41
+
42
+ class WorkflowRunner(TaskRunner):
43
+ def __init__(self, task: Task, *args, **kwargs):
44
+ super().__init__(task=task, *args, **kwargs)
45
+
46
+ async def do_run(self, context: Context = None) -> TaskResponse:
47
+ self.max_steps = self.conf.get("max_steps", 100)
48
+ resp = await self._do_run(context)
49
+ self._task_response = resp
50
+ return resp
51
+
52
+ async def _do_run(self, context: Context = None) -> TaskResponse:
53
+ """Multi-agent sequence general process workflow.
54
+
55
+ NOTE: Use the agent's finished state(no tool calls) to control the inner loop.
56
+ Args:
57
+ observation: Observation based on env
58
+ info: Extend info by env
59
+ """
60
+ observation = self.observation
61
+ if not observation:
62
+ raise RuntimeError("no observation, check run process")
63
+
64
+ start = time.time()
65
+ msg = None
66
+ response = None
67
+
68
+ # Use trace.span to record the entire task execution process
69
+ with trace.span(f"task_execution_{self.task.id}", attributes={
70
+ "task_id": self.task.id,
71
+ "task_name": self.task.name,
72
+ "start_time": start
73
+ }) as task_span:
74
+ try:
75
+ response = await self._common_process(task_span)
76
+ except Exception as err:
77
+ logger.error(f"Runner run failed, err is {traceback.format_exc()}")
78
+ finally:
79
+ await self.outputs.mark_completed()
80
+ color_log(f"task token usage: {self.context.token_usage}",
81
+ color=Color.pink,
82
+ logger_=trace_logger)
83
+ for _, tool in self.tools.items():
84
+ if isinstance(tool, AsyncTool):
85
+ await tool.close()
86
+ else:
87
+ tool.close()
88
+ task_span.set_attributes({
89
+ "end_time": time.time(),
90
+ "duration": time.time() - start,
91
+ "error": msg
92
+ })
93
+ # todo sandbox cleanup
94
+ if self.swarm and hasattr(self.swarm, 'agents') and self.swarm.agents:
95
+ for agent_name, agent in self.swarm.agents.items():
96
+ try:
97
+ if hasattr(agent, 'sandbox') and agent.sandbox:
98
+ await agent.sandbox.cleanup()
99
+ except Exception as e:
100
+ logger.warning(f"call_driven_runner Failed to cleanup sandbox for agent {agent_name}: {e}")
101
+ return response
102
+
103
+ async def _common_process(self, task_span):
104
+ start = time.time()
105
+ step = 1
106
+ pre_agent_name = None
107
+ observation = self.observation
108
+
109
+ for idx, agent in enumerate(self.swarm.ordered_agents):
110
+ observation.from_agent_name = agent.id()
111
+ observations = [observation]
112
+ policy = None
113
+ cur_agent = agent
114
+ while step <= self.max_steps:
115
+ await self.outputs.add_output(
116
+ StepOutput.build_start_output(name=f"Step{step}", step_num=step))
117
+
118
+ terminated = False
119
+
120
+ observation = self.swarm.action_to_observation(policy, observations)
121
+ observation.from_agent_name = observation.from_agent_name or cur_agent.id()
122
+
123
+ if observation.to_agent_name and observation.to_agent_name != cur_agent.id():
124
+ cur_agent = self.swarm.agents.get(observation.to_agent_name)
125
+
126
+ exp_id = self._get_step_span_id(step, cur_agent.id())
127
+ with trace.span(f"step_execution_{exp_id}") as step_span:
128
+ try:
129
+ step_span.set_attributes({
130
+ "exp_id": exp_id,
131
+ "task_id": self.task.id,
132
+ "task_name": self.task.name,
133
+ "trace_id": trace.get_current_span().get_trace_id(),
134
+ "step": step,
135
+ "agent_id": cur_agent.id(),
136
+ "pre_agent": pre_agent_name,
137
+ "observation": json.dumps(observation.model_dump(exclude_none=True),
138
+ ensure_ascii=False,
139
+ cls=NumpyEncoder)
140
+ })
141
+ except:
142
+ pass
143
+ pre_agent_name = cur_agent.id()
144
+
145
+ if not override_in_subclass('async_policy', cur_agent.__class__, Agent):
146
+ message = cur_agent.run(observation,
147
+ step=step,
148
+ outputs=self.outputs,
149
+ stream=self.conf.get("stream", False),
150
+ exp_id=exp_id)
151
+ else:
152
+ message = await cur_agent.async_run(observation,
153
+ step=step,
154
+ outputs=self.outputs,
155
+ stream=self.conf.get("stream",
156
+ False),
157
+ exp_id=exp_id)
158
+ policy = message.payload
159
+ step_span.set_attribute("actions",
160
+ json.dumps([action.model_dump() for action in policy],
161
+ ensure_ascii=False))
162
+ observation.content = None
163
+ color_log(f"{cur_agent.id()} policy: {policy}")
164
+ if not policy:
165
+ logger.warning(f"current agent {cur_agent.id()} no policy to use.")
166
+ await self.outputs.add_output(
167
+ StepOutput.build_failed_output(name=f"Step{step}",
168
+ step_num=step,
169
+ data=f"current agent {cur_agent.id()} no policy to use.")
170
+ )
171
+ await self.outputs.mark_completed()
172
+ task_span.set_attributes({
173
+ "end_time": time.time(),
174
+ "duration": time.time() - start,
175
+ "status": "failed",
176
+ "error": f"current agent {cur_agent.id()} no policy to use."
177
+ })
178
+ return TaskResponse(msg=f"current agent {cur_agent.id()} no policy to use.",
179
+ answer="",
180
+ success=False,
181
+ id=self.task.id,
182
+ time_cost=(time.time() - start),
183
+ usage=self.context.token_usage)
184
+
185
+ if is_agent(policy[0]):
186
+ status, info = await self._agent(agent, observation, policy, step)
187
+ if status == 'normal':
188
+ if info:
189
+ observations.append(observation)
190
+ elif status == 'break':
191
+ observation = self.swarm.action_to_observation(policy, observations)
192
+ if idx == len(self.swarm.ordered_agents) - 1:
193
+ return TaskResponse(
194
+ answer=observation.content,
195
+ success=True,
196
+ id=self.task.id,
197
+ time_cost=(time.time() - start),
198
+ usage=self.context.token_usage
199
+ )
200
+ break
201
+ elif status == 'return':
202
+ await self.outputs.add_output(
203
+ StepOutput.build_finished_output(name=f"Step{step}", step_num=step)
204
+ )
205
+ info.time_cost = (time.time() - start)
206
+ task_span.set_attributes({
207
+ "end_time": time.time(),
208
+ "duration": info.time_cost,
209
+ "status": "success"
210
+ })
211
+ return info
212
+ elif is_tool_by_name(policy[0].tool_name):
213
+ # todo sandbox
214
+ msg, reward, terminated = await self._tool_call(policy, observations, step,
215
+ cur_agent)
216
+ step_span.set_attribute("reward", reward)
217
+
218
+ else:
219
+ logger.warning(f"Unrecognized policy: {policy[0]}")
220
+ await self.outputs.add_output(
221
+ StepOutput.build_failed_output(name=f"Step{step}",
222
+ step_num=step,
223
+ data=f"Unrecognized policy: {policy[0]}, need to check prompt or agent / tool.")
224
+ )
225
+ await self.outputs.mark_completed()
226
+ task_span.set_attributes({
227
+ "end_time": time.time(),
228
+ "duration": time.time() - start,
229
+ "status": "failed",
230
+ "error": f"Unrecognized policy: {policy[0]}, need to check prompt or agent / tool."
231
+ })
232
+ return TaskResponse(
233
+ msg=f"Unrecognized policy: {policy[0]}, need to check prompt or agent / tool.",
234
+ answer="",
235
+ success=False,
236
+ id=self.task.id,
237
+ time_cost=(time.time() - start),
238
+ usage=self.context.token_usage
239
+ )
240
+ await self.outputs.add_output(
241
+ StepOutput.build_finished_output(name=f"Step{step}",
242
+ step_num=step, )
243
+ )
244
+ step += 1
245
+ if terminated and agent.finished:
246
+ logger.info(f"{agent.id()} finished")
247
+ if idx == len(self.swarm.ordered_agents) - 1:
248
+ return TaskResponse(
249
+ answer=observations[-1].content,
250
+ success=True,
251
+ id=self.task.id,
252
+ time_cost=(time.time() - start),
253
+ usage=self.context.token_usage
254
+ )
255
+ break
256
+
257
+ async def _agent(self, agent: Agent, observation: Observation, policy: List[ActionModel], step: int):
258
+ # only one agent, and get agent from policy
259
+ policy_for_agent = policy[0]
260
+ agent_name = policy_for_agent.tool_name
261
+ if not agent_name:
262
+ agent_name = policy_for_agent.agent_name
263
+ cur_agent: Agent = self.swarm.agents.get(agent_name)
264
+ if not cur_agent:
265
+ raise RuntimeError(f"Can not find {agent_name} agent in swarm.")
266
+
267
+ status = "normal"
268
+ if cur_agent.id() == agent.id():
269
+ # Current agent is entrance agent, means need to exit to the outer loop
270
+ logger.info(f"{cur_agent.id()} exit the loop")
271
+ status = "break"
272
+ return status, None
273
+
274
+ if agent.handoffs and agent_name not in agent.handoffs:
275
+ # Unable to hand off, exit to the outer loop
276
+ status = "return"
277
+ return status, TaskResponse(msg=f"Can not handoffs {agent_name} agent ",
278
+ answer=observation.content,
279
+ success=False,
280
+ id=self.task.id,
281
+ usage=self.context.token_usage)
282
+ # Check if current agent done
283
+ if cur_agent.finished:
284
+ cur_agent._finished = False
285
+ logger.info(f"{cur_agent.id()} agent be be handed off, so finished state reset to False.")
286
+
287
+ con = policy_for_agent.policy_info
288
+ if policy_for_agent.params and 'content' in policy_for_agent.params:
289
+ con = policy_for_agent.params['content']
290
+ if observation:
291
+ observation.content = con
292
+ else:
293
+ observation = Observation(content=con)
294
+ return status, observation
295
+ return status, None
296
+
297
+ # todo sandbox
298
+ async def _tool_call(self, policy: List[ActionModel], observations: List[Observation], step: int, agent: Agent):
299
+ msg = None
300
+ terminated = False
301
+ # group action by tool name
302
+ tool_mapping = dict()
303
+ reward = 0.0
304
+ # Directly use or use tools after creation.
305
+ for act in policy:
306
+ if not self.tools or (self.tools and act.tool_name not in self.tools):
307
+ # dynamic only use default config in module.
308
+ conf = self.tools_conf.get(act.tool_name)
309
+ tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False)
310
+ if isinstance(tool, Tool):
311
+ tool.reset()
312
+ elif isinstance(tool, AsyncTool):
313
+ await tool.reset()
314
+ tool_mapping[act.tool_name] = []
315
+ self.tools[act.tool_name] = tool
316
+ if act.tool_name not in tool_mapping:
317
+ tool_mapping[act.tool_name] = []
318
+ tool_mapping[act.tool_name].append(act)
319
+
320
+ for tool_name, action in tool_mapping.items():
321
+ # Execute action using browser tool and unpack all return values
322
+ if isinstance(self.tools[tool_name], Tool):
323
+ message = self.tools[tool_name].step(action)
324
+ elif isinstance(self.tools[tool_name], AsyncTool):
325
+ # todo sandbox
326
+ message = await self.tools[tool_name].step(action, agent=agent)
327
+ else:
328
+ logger.warning(f"Unsupported tool type: {self.tools[tool_name]}")
329
+ continue
330
+
331
+ observation, reward, terminated, _, info = message.payload
332
+ # observation, reward, terminated, _, info = action_result_transform(message, sandbox=None)
333
+ observations.append(observation)
334
+ for i, item in enumerate(action):
335
+ tool_output = ToolResultOutput(
336
+ tool_type=tool_name,
337
+ tool_name=item.tool_name,
338
+ data=observation.content,
339
+ origin_tool_call=ToolCall.from_dict({
340
+ "function": {
341
+ "name": item.action_name,
342
+ "arguments": item.params,
343
+ }
344
+ })
345
+ )
346
+ await self.outputs.add_output(tool_output)
347
+
348
+ # Check if there's an exception in info
349
+ if info.get("exception"):
350
+ color_log(f"Step {step} failed with exception: {info['exception']}", color=Color.red)
351
+ msg = f"Step {step} failed with exception: {info['exception']}"
352
+ logger.info(f"step: {step} finished by tool action: {action}.")
353
+ log_ob = Observation(content='' if observation.content is None else observation.content,
354
+ action_result=observation.action_result)
355
+ trace_logger.info(f"{tool_name} observation: {log_ob}", color=Color.green)
356
+ return msg, reward, terminated
357
+
358
+ def _get_step_span_id(self, step, cur_agent_name):
359
+ key = (step, cur_agent_name)
360
+ if key not in self.step_agent_counter:
361
+ self.step_agent_counter[key] = 0
362
+ else:
363
+ self.step_agent_counter[key] += 1
364
+ exp_index = self.step_agent_counter[key]
365
+
366
+ return f"{self.task.id}_{step}_{cur_agent_name}_{exp_index}"
367
+
368
+
369
+ class LoopWorkflowRunner(WorkflowRunner):
370
+
371
+ async def _do_run(self, context: Context = None) -> TaskResponse:
372
+ observation = self.observation
373
+ if not observation:
374
+ raise RuntimeError("no observation, check run process")
375
+
376
+ start = time.time()
377
+ step = 1
378
+ msg = None
379
+
380
+ # Use trace.span to record the entire task execution process
381
+ with trace.span(f"task_execution_{self.task.id}", attributes={
382
+ "task_id": self.task.id,
383
+ "task_name": self.task.name,
384
+ "start_time": start
385
+ }) as task_span:
386
+ try:
387
+ for i in range(self.max_steps):
388
+ await self._common_process(task_span)
389
+ step += 1
390
+ except Exception as err:
391
+ logger.error(f"Runner run failed, err is {traceback.format_exc()}")
392
+ finally:
393
+ await self.outputs.mark_completed()
394
+ color_log(f"task token usage: {self.context.token_usage}",
395
+ color=Color.pink,
396
+ logger_=trace_logger)
397
+ for _, tool in self.tools.items():
398
+ if isinstance(tool, AsyncTool):
399
+ await tool.close()
400
+ else:
401
+ tool.close()
402
+ task_span.set_attributes({
403
+ "end_time": time.time(),
404
+ "duration": time.time() - start,
405
+ "error": msg
406
+ })
407
+ return TaskResponse(msg=msg,
408
+ answer=observation.content,
409
+ success=True if not msg else False,
410
+ id=self.task.id,
411
+ time_cost=(time.time() - start),
412
+ usage=self.context.token_usage)
413
+
414
+
415
+ class HandoffRunner(TaskRunner):
416
+ def __init__(self, task: Task, *args, **kwargs):
417
+ super().__init__(task=task, *args, **kwargs)
418
+
419
+ async def do_run(self, context: Context = None) -> TaskResponse:
420
+ resp = await self._do_run(context)
421
+ self._task_response = resp
422
+ return resp
423
+
424
+ async def _do_run(self, context: Context = None) -> TaskResponse:
425
+ """Multi-agent general process based on handoff.
426
+
427
+ NOTE: Use the agent's finished state to control the loop, so the agent must carefully set finished state.
428
+
429
+ Args:
430
+ context: Context of runner.
431
+ """
432
+ start = time.time()
433
+
434
+ observation = self.observation
435
+ info = dict()
436
+ step = 0
437
+ max_steps = self.conf.get("max_steps", 100)
438
+ results = []
439
+ swarm_resp = None
440
+ self.loop_detect = []
441
+ # Use trace.span to record the entire task execution process
442
+ with trace.span(f"task_execution_{self.task.id}", attributes={
443
+ "task_id": self.task.id,
444
+ "task_name": self.task.name,
445
+ "start_time": start
446
+ }) as task_span:
447
+ try:
448
+ while step < max_steps:
449
+ # Loose protocol
450
+ result_dict = await self._process(observation=observation, info=info)
451
+ results.append(result_dict)
452
+
453
+ swarm_resp = result_dict.get("response")
454
+ logger.info(f"Step: {step} response:\n {result_dict}")
455
+
456
+ step += 1
457
+ if self.swarm.finished or endless_detect(self.loop_detect,
458
+ self.endless_threshold,
459
+ self.swarm.communicate_agent.id()):
460
+ logger.info("task done!")
461
+ break
462
+
463
+ if not swarm_resp:
464
+ logger.warning(f"Step: {step} swarm no valid response")
465
+ break
466
+
467
+ observation = result_dict.get("observation")
468
+ if not observation:
469
+ observation = Observation(content=swarm_resp)
470
+ else:
471
+ observation.content = swarm_resp
472
+
473
+ time_cost = time.time() - start
474
+ if not results:
475
+ logger.warning("task no result!")
476
+ task_span.set_attributes({
477
+ "status": "failed",
478
+ "error": f"task no result!"
479
+ })
480
+ return TaskResponse(msg=traceback.format_exc(),
481
+ answer='',
482
+ success=False,
483
+ id=self.task.id,
484
+ time_cost=time_cost,
485
+ usage=self.context.token_usage)
486
+
487
+ answer = results[-1].get('observation').content if results[-1].get('observation') else swarm_resp
488
+ return TaskResponse(answer=answer,
489
+ success=True,
490
+ id=self.task.id,
491
+ time_cost=(time.time() - start),
492
+ usage=self.context.token_usage)
493
+ except Exception as e:
494
+ logger.error(f"Task execution failed with error: {str(e)}\n{traceback.format_exc()}")
495
+ task_span.set_attributes({
496
+ "status": "failed",
497
+ "error": f"Task execution failed with error: {str(e)}\n{traceback.format_exc()}"
498
+ })
499
+ return TaskResponse(msg=traceback.format_exc(),
500
+ answer='',
501
+ success=False,
502
+ id=self.task.id,
503
+ time_cost=(time.time() - start),
504
+ usage=self.context.token_usage)
505
+ finally:
506
+ color_log(f"task token usage: {self.context.token_usage}",
507
+ color=Color.pink,
508
+ logger_=trace_logger)
509
+ for _, tool in self.tools.items():
510
+ if isinstance(tool, AsyncTool):
511
+ await tool.close()
512
+ else:
513
+ tool.close()
514
+ task_span.set_attributes({
515
+ "end_time": time.time(),
516
+ "duration": time.time() - start,
517
+ })
518
+
519
+ async def _process(self, observation, info) -> Dict[str, Any]:
520
+ if not self.swarm.initialized:
521
+ raise RuntimeError("swarm needs to use `reset` to init first.")
522
+
523
+ start = time.time()
524
+ step = 0
525
+ max_steps = self.conf.get("max_steps", 100)
526
+ self.swarm.cur_agent = self.swarm.communicate_agent
527
+ pre_agent_name = None
528
+ # use communicate agent every time
529
+ if override_in_subclass('async_policy', self.swarm.cur_agent.__class__, Agent):
530
+ message = self.swarm.cur_agent.run(observation,
531
+ step=step,
532
+ outputs=self.outputs,
533
+ stream=self.conf.get("stream", False))
534
+ else:
535
+ message = await self.swarm.cur_agent.async_run(observation,
536
+ step=step,
537
+ outputs=self.outputs,
538
+ stream=self.conf.get("stream", False))
539
+ self.loop_detect.append(self.swarm.cur_agent.id())
540
+ policy = message.payload
541
+ if not policy:
542
+ logger.warning(f"current agent {self.swarm.cur_agent.id()} no policy to use.")
543
+ exp_id = self._get_step_span_id(step, self.swarm.cur_agent.id())
544
+ with trace.span(f"step_execution_{exp_id}") as step_span:
545
+ step_span.set_attributes({
546
+ "exp_id": exp_id,
547
+ "task_id": self.task.id,
548
+ "task_name": self.task.name,
549
+ "trace_id": trace.get_current_span().get_trace_id(),
550
+ "step": step,
551
+ "agent_id": self.swarm.cur_agent.id(),
552
+ "pre_agent": pre_agent_name,
553
+ "observation": json.dumps(observation.model_dump(exclude_none=True),
554
+ ensure_ascii=False,
555
+ cls=NumpyEncoder),
556
+ "actions": json.dumps([action.model_dump() for action in policy], ensure_ascii=False)
557
+ })
558
+ return {"msg": f"current agent {self.swarm.cur_agent.id()} no policy to use.",
559
+ "steps": step,
560
+ "success": False,
561
+ "time_cost": (time.time() - start)}
562
+ color_log(f"{self.swarm.cur_agent.id()} policy: {policy}")
563
+
564
+ msg = None
565
+ response = None
566
+ return_entry = False
567
+ cur_agent = None
568
+ cur_observation = observation
569
+ finished = False
570
+ try:
571
+ while step < max_steps:
572
+ terminated = False
573
+ exp_id = self._get_step_span_id(step, self.swarm.cur_agent.id())
574
+ with trace.span(f"step_execution_{exp_id}") as step_span:
575
+ try:
576
+ step_span.set_attributes({
577
+ "exp_id": exp_id,
578
+ "task_id": self.task.id,
579
+ "task_name": self.task.name,
580
+ "trace_id": trace.get_current_span().get_trace_id(),
581
+ "step": step,
582
+ "agent_id": self.swarm.cur_agent.id(),
583
+ "pre_agent": pre_agent_name,
584
+ "observation": json.dumps(cur_observation.model_dump(exclude_none=True),
585
+ ensure_ascii=False,
586
+ cls=NumpyEncoder),
587
+ "actions": json.dumps([action.model_dump() for action in policy], ensure_ascii=False)
588
+ })
589
+ except:
590
+ pass
591
+
592
+ if is_agent(policy[0]):
593
+ status, info, ob = await self._social_agent(policy, step)
594
+ if status == 'normal':
595
+ self.swarm.cur_agent = self.swarm.agents.get(policy[0].agent_name)
596
+ policy = info
597
+
598
+ cur_observation = ob
599
+ # clear observation
600
+ observation = None
601
+ elif is_tool_by_name(policy[0].tool_name):
602
+ status, terminated, info = await self._social_tool_call(policy, step)
603
+ if status == 'normal':
604
+ observation = info
605
+ cur_observation = observation
606
+ else:
607
+ logger.warning(f"Unrecognized policy: {policy[0]}")
608
+ return {"msg": f"Unrecognized policy: {policy[0]}, need to check prompt or agent / tool.",
609
+ "response": "",
610
+ "steps": step,
611
+ "success": False}
612
+
613
+ if status == 'break':
614
+ return_entry = info
615
+ break
616
+ elif status == 'return':
617
+ return info
618
+
619
+ step += 1
620
+ pre_agent_name = self.swarm.cur_agent.id()
621
+ if terminated and self.swarm.cur_agent.finished:
622
+ logger.info(f"{self.swarm.cur_agent.id()} finished")
623
+ break
624
+
625
+ if observation:
626
+ if cur_agent is None:
627
+ cur_agent = self.swarm.cur_agent
628
+ if not override_in_subclass('async_policy', cur_agent.__class__, Agent):
629
+ message = cur_agent.run(observation,
630
+ step=step,
631
+ outputs=self.outputs,
632
+ stream=self.conf.get("stream", False))
633
+ else:
634
+ message = await cur_agent.async_run(observation,
635
+ step=step,
636
+ outputs=self.outputs,
637
+ stream=self.conf.get("stream", False))
638
+ policy = message.payload
639
+ color_log(f"{cur_agent.id()} policy: {policy}")
640
+
641
+ if policy:
642
+ response = policy[0].policy_info if policy[0].policy_info else policy[0].action_name
643
+
644
+ # All agents or tools have completed their tasks
645
+ if all(agent.finished for _, agent in self.swarm.agents.items()) or (all(
646
+ tool.finished for _, tool in self.tools.items()) and len(self.swarm.agents) == 1):
647
+ logger.info("entry agent finished, swarm process finished.")
648
+ finished = True
649
+
650
+ if return_entry and not finished:
651
+ # Return to the entrance, reset current agent finished state
652
+ self.swarm.cur_agent._finished = False
653
+ return {"steps": step,
654
+ "response": response,
655
+ "observation": observation,
656
+ "msg": msg,
657
+ "success": True if not msg else False}
658
+ except Exception as e:
659
+ logger.error(f"Task execution failed with error: {str(e)}\n{traceback.format_exc()}")
660
+ return {
661
+ "msg": str(e),
662
+ "response": "",
663
+ "traceback": traceback.format_exc(),
664
+ "steps": step,
665
+ "success": False
666
+ }
667
+
668
+ async def _social_agent(self, policy: List[ActionModel], step):
669
+ # only one agent, and get agent from policy
670
+ policy_for_agent = policy[0]
671
+ agent_name = policy_for_agent.tool_name
672
+ if not agent_name:
673
+ agent_name = policy_for_agent.agent_name
674
+
675
+ cur_agent: Agent = self.swarm.agents.get(agent_name)
676
+ if not cur_agent:
677
+ raise RuntimeError(f"Can not find {agent_name} agent in swarm.")
678
+
679
+ if cur_agent.id() == self.swarm.communicate_agent.id() or cur_agent.id() == self.swarm.cur_agent.id():
680
+ # Current agent is entrance agent, means need to exit to the outer loop
681
+ logger.info(f"{cur_agent.id()} exit to the outer loop")
682
+ return 'break', True, None
683
+
684
+ if self.swarm.cur_agent.handoffs and agent_name not in self.swarm.cur_agent.handoffs:
685
+ # Unable to hand off, exit to the outer loop
686
+ return "return", {"msg": f"Can not handoffs {agent_name} agent "
687
+ f"by {cur_agent.id()} agent.",
688
+ "response": policy[0].policy_info if policy else "",
689
+ "steps": step,
690
+ "success": False}, None
691
+ # Check if current agent done
692
+ if cur_agent.finished:
693
+ cur_agent._finished = False
694
+ logger.info(f"{cur_agent.id()} agent be be handed off, so finished state reset to False.")
695
+
696
+ observation = Observation(content=policy_for_agent.policy_info)
697
+ self.loop_detect.append(cur_agent.id())
698
+ if cur_agent.step_reset:
699
+ cur_agent.reset({"task": observation.content,
700
+ "tool_names": cur_agent.tool_names,
701
+ "agent_names": cur_agent.handoffs,
702
+ "mcp_servers": cur_agent.mcp_servers})
703
+
704
+ if not override_in_subclass('async_policy', cur_agent.__class__, Agent):
705
+ message = cur_agent.run(observation,
706
+ step=step,
707
+ outputs=self.outputs,
708
+ stream=self.conf.get("stream", False))
709
+ else:
710
+ message = await cur_agent.async_run(observation,
711
+ step=step,
712
+ outputs=self.outputs,
713
+ stream=self.conf.get("stream", False))
714
+
715
+ agent_policy = message.payload
716
+ if not agent_policy:
717
+ logger.warning(
718
+ f"{observation} can not get the valid policy in {policy_for_agent.agent_name}, exit task!")
719
+ return "return", {"msg": f"{policy_for_agent.agent_name} invalid policy",
720
+ "response": "",
721
+ "steps": step,
722
+ "success": False}, None
723
+ color_log(f"{cur_agent.id()} policy: {agent_policy}")
724
+ return 'normal', agent_policy, observation
725
+
726
+ async def _social_tool_call(self, policy: List[ActionModel], step: int):
727
+ observation = None
728
+ terminated = False
729
+ # group action by tool name
730
+ tool_mapping = dict()
731
+ # Directly use or use tools after creation.
732
+ for act in policy:
733
+ if not self.tools or (self.tools and act.tool_name not in self.tools):
734
+ # dynamic only use default config in module.
735
+ conf: ToolConfig = self.tools_conf.get(act.tool_name)
736
+ tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False)
737
+ if isinstance(tool, Tool):
738
+ tool.reset()
739
+ elif isinstance(tool, AsyncTool):
740
+ await tool.reset()
741
+
742
+ tool_mapping[act.tool_name] = []
743
+ self.tools[act.tool_name] = tool
744
+ if act.tool_name not in tool_mapping:
745
+ tool_mapping[act.tool_name] = []
746
+ tool_mapping[act.tool_name].append(act)
747
+
748
+ for tool_name, action in tool_mapping.items():
749
+ # Execute action using browser tool and unpack all return values
750
+ if isinstance(self.tools[tool_name], Tool):
751
+ message = self.tools[tool_name].step(action)
752
+ elif isinstance(self.tools[tool_name], AsyncTool):
753
+ message = await self.tools[tool_name].step(action)
754
+ else:
755
+ logger.warning(f"Unsupported tool type: {self.tools[tool_name]}")
756
+ continue
757
+
758
+ observation, reward, terminated, _, info = message.payload
759
+ for i, item in enumerate(action):
760
+ tool_output = ToolResultOutput(data=observation.content, origin_tool_call=ToolCall.from_dict({
761
+ "function": {
762
+ "name": item.action_name,
763
+ "arguments": item.params,
764
+ }
765
+ }))
766
+ await self.outputs.add_output(tool_output)
767
+
768
+ # Check if there's an exception in info
769
+ if info.get("exception"):
770
+ color_log(f"Step {step} failed with exception: {info['exception']}", color=Color.red)
771
+ logger.info(f"step: {step} finished by tool action {action}.")
772
+ log_ob = Observation(content='' if observation.content is None else observation.content,
773
+ action_result=observation.action_result)
774
+ color_log(f"{tool_name} observation: {log_ob}", color=Color.green)
775
+
776
+ # The tool results give itself, exit; give to other agents, continue
777
+ tmp_name = policy[0].agent_name
778
+ if self.swarm.cur_agent.id() == self.swarm.communicate_agent.id() and (
779
+ len(self.swarm.agents) == 1 or tmp_name is None or self.swarm.cur_agent.id() == tmp_name):
780
+ return "break", terminated, True
781
+ elif policy[0].agent_name:
782
+ policy_for_agent = policy[0]
783
+ agent_name = policy_for_agent.agent_name
784
+ if not agent_name:
785
+ agent_name = policy_for_agent.tool_name
786
+ cur_agent: Agent = self.swarm.agents.get(agent_name)
787
+ if not cur_agent:
788
+ raise RuntimeError(f"Can not find {agent_name} agent in swarm.")
789
+ if self.swarm.cur_agent.handoffs and agent_name not in self.swarm.cur_agent.handoffs:
790
+ # Unable to hand off, exit to the outer loop
791
+ return "return", {"msg": f"Can not handoffs {agent_name} agent "
792
+ f"by {cur_agent.id()} agent.",
793
+ "response": policy[0].policy_info if policy else "",
794
+ "steps": step,
795
+ "success": False}
796
+ # Check if current agent done
797
+ if cur_agent.finished:
798
+ cur_agent._finished = False
799
+ logger.info(f"{cur_agent.id()} agent be be handed off, so finished state reset to False.")
800
+ return "normal", terminated, observation
801
+
802
+ def _get_step_span_id(self, step, cur_agent_name):
803
+ key = (step, cur_agent_name)
804
+ if key not in self.step_agent_counter:
805
+ self.step_agent_counter[key] = 0
806
+ else:
807
+ self.step_agent_counter[key] += 1
808
+ exp_index = self.step_agent_counter[key]
809
+
810
+ return f"{self.task.id}_{step}_{cur_agent_name}_{exp_index}"
aworld/runners/event_runner.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import asyncio
4
+ import time
5
+ import traceback
6
+ import aworld.trace as trace
7
+ from typing import List, Callable, Any
8
+
9
+ from aworld.core.common import TaskItem
10
+ from aworld.core.context.base import Context
11
+
12
+ from aworld.agents.llm_agent import Agent
13
+ from aworld.core.event.base import Message, Constants, TopicType, ToolMessage, AgentMessage
14
+ from aworld.core.task import Task, TaskResponse
15
+ from aworld.events.manager import EventManager
16
+ from aworld.logs.util import logger
17
+ from aworld.runners.handler.agent import DefaultAgentHandler, AgentHandler
18
+ from aworld.runners.handler.base import DefaultHandler
19
+ from aworld.runners.handler.output import DefaultOutputHandler
20
+ from aworld.runners.handler.task import DefaultTaskHandler, TaskHandler
21
+ from aworld.runners.handler.tool import DefaultToolHandler, ToolHandler
22
+
23
+ from aworld.runners.task_runner import TaskRunner
24
+ from aworld.utils.common import override_in_subclass, new_instance
25
+ from aworld.runners.state_manager import EventRuntimeStateManager
26
+
27
+
28
+ class TaskEventRunner(TaskRunner):
29
+ """Event driven task runner."""
30
+
31
+ def __init__(self, task: Task, *args, **kwargs):
32
+ super().__init__(task, *args, **kwargs)
33
+ self._task_response = None
34
+ self.event_mng = EventManager(self.context)
35
+ self.hooks = {}
36
+ self.background_tasks = set()
37
+ self.state_manager = EventRuntimeStateManager.instance()
38
+
39
+ async def pre_run(self):
40
+ await super().pre_run()
41
+
42
+ if self.swarm and not self.swarm.max_steps:
43
+ self.swarm.max_steps = self.task.conf.get('max_steps', 10)
44
+ observation = self.observation
45
+ if not observation:
46
+ raise RuntimeError("no observation, check run process")
47
+
48
+ self._build_first_message()
49
+
50
+ if self.swarm:
51
+ # register agent handler
52
+ for _, agent in self.swarm.agents.items():
53
+ agent.set_tools_instances(self.tools, self.tools_conf)
54
+ if agent.handler:
55
+ await self.event_mng.register(Constants.AGENT, agent.id(), agent.handler)
56
+ else:
57
+ if override_in_subclass('async_policy', agent.__class__, Agent):
58
+ await self.event_mng.register(Constants.AGENT, agent.id(), agent.async_run)
59
+ else:
60
+ await self.event_mng.register(Constants.AGENT, agent.id(), agent.run)
61
+ # register tool handler
62
+ for key, tool in self.tools.items():
63
+ if tool.handler:
64
+ await self.event_mng.register(Constants.TOOL, tool.name(), tool.handler)
65
+ else:
66
+ await self.event_mng.register(Constants.TOOL, tool.name(), tool.step)
67
+ handlers = self.event_mng.event_bus.get_topic_handlers(
68
+ Constants.TOOL, tool.name())
69
+ if not handlers:
70
+ await self.event_mng.register(Constants.TOOL, Constants.TOOL, tool.step)
71
+
72
+ self._stopped = asyncio.Event()
73
+
74
+ # handler of process in framework
75
+ handler_list = self.conf.get("handlers")
76
+ if handler_list:
77
+ handlers = []
78
+ for hand in handler_list:
79
+ handlers.append(new_instance(hand, self))
80
+
81
+ has_task_handler = False
82
+ has_tool_handler = False
83
+ has_agent_handler = False
84
+ for hand in handlers:
85
+ if isinstance(hand, TaskHandler):
86
+ has_task_handler = True
87
+ elif isinstance(hand, ToolHandler):
88
+ has_tool_handler = True
89
+ elif isinstance(hand, AgentHandler):
90
+ has_agent_handler = True
91
+
92
+ if not has_agent_handler:
93
+ self.handlers.append(DefaultAgentHandler(runner=self))
94
+ if not has_tool_handler:
95
+ self.handlers.append(DefaultToolHandler(runner=self))
96
+ if not has_task_handler:
97
+ self.handlers.append(DefaultTaskHandler(runner=self))
98
+ self.handlers = handlers
99
+ else:
100
+ self.handlers = [DefaultAgentHandler(runner=self),
101
+ DefaultToolHandler(runner=self),
102
+ DefaultTaskHandler(runner=self),
103
+ DefaultOutputHandler(runner=self)]
104
+
105
+ def _build_first_message(self):
106
+ # build the first message
107
+ if self.agent_oriented:
108
+ self.init_message = AgentMessage(payload=self.observation,
109
+ sender='runner',
110
+ receiver=self.swarm.communicate_agent.id(),
111
+ session_id=self.context.session_id,
112
+ headers={'context': self.context})
113
+ else:
114
+ actions = self.observation.content
115
+ receiver = actions[0].tool_name
116
+ self.init_message = ToolMessage(payload=self.observation.content,
117
+ sender='runner',
118
+ receiver=receiver,
119
+ session_id=self.context.session_id,
120
+ headers={'context': self.context})
121
+
122
+ async def _common_process(self, message: Message) -> List[Message]:
123
+ event_bus = self.event_mng.event_bus
124
+
125
+ key = message.category
126
+ transformer = event_bus.get_transform_handlers(key)
127
+ if transformer:
128
+ message = await event_bus.transform(message, handler=transformer)
129
+
130
+ results = []
131
+ handlers = event_bus.get_handlers(key)
132
+ async with trace.message_span(message=message):
133
+ self.state_manager.start_message_node(message)
134
+ if handlers:
135
+ if message.topic:
136
+ handlers = {message.topic: handlers.get(message.topic, [])}
137
+ elif message.receiver:
138
+ handlers = {message.receiver: handlers.get(
139
+ message.receiver, [])}
140
+
141
+ for topic, handler_list in handlers.items():
142
+ if not handler_list:
143
+ logger.warning(f"{topic} no handler, ignore.")
144
+ continue
145
+
146
+ for handler in handler_list:
147
+ t = asyncio.create_task(
148
+ self._handle_task(message, handler))
149
+ self.background_tasks.add(t)
150
+ t.add_done_callback(self.background_tasks.discard)
151
+ else:
152
+ # not handler, return raw message
153
+ results.append(message)
154
+
155
+ t = asyncio.create_task(self._raw_task(results))
156
+ self.background_tasks.add(t)
157
+ t.add_done_callback(self.background_tasks.discard)
158
+ # wait until it is complete
159
+ await t
160
+ self.state_manager.end_message_node(message)
161
+ return results
162
+
163
+ async def _handle_task(self, message: Message, handler: Callable[..., Any]):
164
+ con = message
165
+ async with trace.span(handler.__name__):
166
+ try:
167
+ logger.info(
168
+ f"event_runner _handle_task start, message: {message.id}")
169
+ if asyncio.iscoroutinefunction(handler):
170
+ con = await handler(con)
171
+ else:
172
+ con = handler(con)
173
+
174
+ logger.info(f"event_runner _handle_task message= {message.id}")
175
+ if isinstance(con, Message):
176
+ # process in framework
177
+ self.state_manager.save_message_handle_result(name=handler.__name__,
178
+ message=message,
179
+ result=con)
180
+ async for event in self._inner_handler_process(
181
+ results=[con],
182
+ handlers=self.handlers
183
+ ):
184
+ await self.event_mng.emit_message(event)
185
+ else:
186
+ self.state_manager.save_message_handle_result(name=handler.__name__,
187
+ message=message)
188
+ except Exception as e:
189
+ logger.warning(
190
+ f"{handler} process fail. {traceback.format_exc()}")
191
+ error_msg = Message(
192
+ category=Constants.TASK,
193
+ payload=TaskItem(msg=str(e), data=message),
194
+ sender=self.name,
195
+ session_id=Context.instance().session_id,
196
+ topic=TopicType.ERROR
197
+ )
198
+ self.state_manager.save_message_handle_result(name=handler.__name__,
199
+ message=message,
200
+ result=error_msg)
201
+ await self.event_mng.event_bus.publish(error_msg)
202
+
203
+ async def _raw_task(self, messages: List[Message]):
204
+ # process in framework
205
+ async for event in self._inner_handler_process(
206
+ results=messages,
207
+ handlers=self.handlers
208
+ ):
209
+ await self.event_mng.emit_message(event)
210
+
211
+ async def _inner_handler_process(self, results: List[Message], handlers: List[DefaultHandler]):
212
+ # can use runtime backend to parallel
213
+ for handler in handlers:
214
+ for result in results:
215
+ async for event in handler.handle(result):
216
+ yield event
217
+
218
+ async def _do_run(self):
219
+ """Task execution process in real."""
220
+ start = time.time()
221
+ msg = None
222
+ answer = None
223
+
224
+ try:
225
+ while True:
226
+ if await self.is_stopped():
227
+ await self.event_mng.done()
228
+ logger.info("stop task...")
229
+ if self._task_response is None:
230
+ # send msg to output
231
+ self._task_response = TaskResponse(msg=msg,
232
+ answer=answer,
233
+ success=True if not msg else False,
234
+ id=self.task.id,
235
+ time_cost=(
236
+ time.time() - start),
237
+ usage=self.context.token_usage)
238
+ break
239
+
240
+ # consume message
241
+ message: Message = await self.event_mng.consume()
242
+
243
+ # use registered handler to process message
244
+ await self._common_process(message)
245
+ except Exception as e:
246
+ logger.error(f"consume message fail. {traceback.format_exc()}")
247
+ finally:
248
+ if await self.is_stopped():
249
+ await self.task.outputs.mark_completed()
250
+ # todo sandbox cleanup
251
+ if self.swarm and hasattr(self.swarm, 'agents') and self.swarm.agents:
252
+ for agent_name, agent in self.swarm.agents.items():
253
+ try:
254
+ if hasattr(agent, 'sandbox') and agent.sandbox:
255
+ await agent.sandbox.cleanup()
256
+ except Exception as e:
257
+ logger.warning(
258
+ f"event_runner Failed to cleanup sandbox for agent {agent_name}: {e}")
259
+
260
+ async def do_run(self, context: Context = None):
261
+ if self.swarm and not self.swarm.initialized:
262
+ raise RuntimeError("swarm needs to use `reset` to init first.")
263
+ async with trace.span("Task_" + self.init_message.session_id):
264
+ await self.event_mng.emit_message(self.init_message)
265
+ await self._do_run()
266
+ return self._task_response
267
+
268
+ async def stop(self):
269
+ self._stopped.set()
270
+
271
+ async def is_stopped(self):
272
+ return self._stopped.is_set()
273
+
274
+ def response(self):
275
+ return self._task_response
aworld/runners/state_manager.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import time
3
+ from pydantic import BaseModel
4
+ from typing import Optional, List
5
+ from aworld.core.event.base import Message
6
+ from enum import Enum
7
+ from abc import ABC, abstractmethod, ABCMeta
8
+ from aworld.core.agent.base import is_agent_by_name
9
+ from aworld.core.tool.tool_desc import is_tool_by_name
10
+ from aworld.core.singleton import InheritanceSingleton, SingletonMeta
11
+ from aworld.core.event.base import Constants
12
+ from aworld.logs.util import logger
13
+
14
+
15
+ class RunNodeBusiType(Enum):
16
+ AGENT = 'AGENT'
17
+ TOOL = 'TOOL'
18
+ TASK = 'TASK'
19
+
20
+ @staticmethod
21
+ def from_message_category(category: str) -> 'RunNodeBusiType':
22
+ if category == Constants.AGENT:
23
+ return RunNodeBusiType.AGENT
24
+ if category == Constants.TOOL:
25
+ return RunNodeBusiType.TOOL
26
+ if category == Constants.TASK:
27
+ return RunNodeBusiType.TASK
28
+ return None
29
+
30
+
31
+ class RunNodeStatus(Enum):
32
+ INIT = 'INIT'
33
+ RUNNING = 'RUNNING'
34
+ BREAKED = 'BREAKED'
35
+ SUCCESS = 'SUCCESS'
36
+ FAILED = 'FAILED'
37
+ TIMEOUNT = 'TIMEOUNT'
38
+
39
+
40
+ class HandleResult(BaseModel):
41
+ name: str = None
42
+ status: RunNodeStatus = None
43
+ result_msg: Optional[str] = None
44
+ result: Optional[Message] = None
45
+
46
+
47
+ class RunNode(BaseModel):
48
+ # {busi_id}_{busi_type}
49
+ node_id: Optional[str] = None
50
+ busi_type: str = None
51
+ busi_id: str = None
52
+ session_id: str = None
53
+ msg_id: Optional[str] = None # input message id
54
+ # busi_id of node that send the input message
55
+ msg_from: Optional[str] = None
56
+ parent_node_id: Optional[str] = None
57
+ status: RunNodeStatus = None
58
+ result_msg: Optional[str] = None
59
+ results: Optional[List[HandleResult]] = None
60
+ create_time: Optional[float] = None
61
+ execute_time: Optional[float] = None
62
+ end_time: Optional[float] = None
63
+
64
+
65
+ class StateStorage(ABC):
66
+ @abstractmethod
67
+ def get(self, node_id: str) -> RunNode:
68
+ pass
69
+
70
+ @abstractmethod
71
+ def insert(self, node: RunNode):
72
+ pass
73
+
74
+ @abstractmethod
75
+ def update(self, node: RunNode):
76
+ pass
77
+
78
+ @abstractmethod
79
+ def query(self, session_id: str) -> List[RunNode]:
80
+ pass
81
+
82
+
83
+ class StateStorageMeta(SingletonMeta, ABCMeta):
84
+ pass
85
+
86
+
87
+ class InMemoryStateStorage(StateStorage, InheritanceSingleton, metaclass=StateStorageMeta):
88
+ '''
89
+ In memory state storage
90
+ '''
91
+
92
+ def __init__(self, max_session=1000):
93
+ self._max_session = max_session
94
+ self._nodes = {} # {node_id: RunNode}
95
+ self._ordered_session_ids = []
96
+ self._session_nodes = {} # {session_id: [RunNode, RunNode]}
97
+
98
+ def get(self, node_id: str) -> RunNode:
99
+ return self._nodes.get(node_id)
100
+
101
+ def insert(self, node: RunNode):
102
+ if node.session_id not in self._ordered_session_ids:
103
+ self._ordered_session_ids.append(node.session_id)
104
+ self._session_nodes.update({node.session_id: []})
105
+ if node.node_id not in self._nodes:
106
+ self._nodes.update({node.node_id: node})
107
+ self._session_nodes[node.session_id].append(node)
108
+
109
+ if len(self._ordered_session_ids) > self._max_session:
110
+ oldest_session_id = self._ordered_session_ids.pop(0)
111
+ session_nodes = self._session_nodes.pop(oldest_session_id)
112
+ for node in session_nodes:
113
+ self._nodes.pop(node.node_id)
114
+ # logger.info(f"storage nodes: {self._nodes}")
115
+
116
+ def update(self, node: RunNode):
117
+ self._nodes[node.node_id] = node
118
+
119
+ def query(self, session_id: str, msg_id: str = None) -> List[RunNode]:
120
+ session_nodes = self._session_nodes.get(session_id, [])
121
+ if msg_id:
122
+ return [node for node in session_nodes if node.msg_id == msg_id]
123
+ return session_nodes
124
+
125
+
126
+ class RuntimeStateManager(InheritanceSingleton):
127
+ '''
128
+ Runtime state manager
129
+ '''
130
+
131
+ def __init__(self, storage: StateStorage = InMemoryStateStorage.instance()):
132
+ self.storage = storage
133
+
134
+ def create_node(self,
135
+ busi_type: RunNodeBusiType,
136
+ busi_id: str,
137
+ session_id: str,
138
+ node_id: str = None,
139
+ parent_node_id: str = None,
140
+ msg_id: str = None,
141
+ msg_from: str = None) -> RunNode:
142
+ '''
143
+ create node and insert to storage
144
+ '''
145
+ node_id = node_id or msg_id
146
+ node = self._find_node(node_id)
147
+ if node:
148
+ # raise Exception(f"node already exist, node_id: {node_id}")
149
+ return
150
+ if parent_node_id:
151
+ parent_node = self._find_node(parent_node_id)
152
+ if not parent_node:
153
+ logger.warning(
154
+ f"parent node not exist, parent_node_id: {parent_node_id}")
155
+ node = RunNode(node_id=node_id,
156
+ busi_type=busi_type,
157
+ busi_id=busi_id,
158
+ session_id=session_id,
159
+ msg_id=msg_id,
160
+ msg_from=msg_from,
161
+ parent_node_id=parent_node_id,
162
+ status=RunNodeStatus.INIT,
163
+ create_time=time.time())
164
+ self.storage.insert(node)
165
+ return node
166
+
167
+ def run_node(self, node_id: str):
168
+ '''
169
+ set node status to RUNNING and update to storage
170
+ '''
171
+ node = self._node_exist(node_id)
172
+ node.status = RunNodeStatus.RUNNING
173
+ node.execute_time = time.time()
174
+ self.storage.update(node)
175
+
176
+ def save_result(self,
177
+ node_id: str,
178
+ result: HandleResult):
179
+ '''
180
+ save node execute result and update to storage
181
+ '''
182
+ node = self._node_exist(node_id)
183
+ if not node.results:
184
+ node.results = []
185
+ node.results.append(result)
186
+ self.storage.update(node)
187
+
188
+ def break_node(self, node_id):
189
+ '''
190
+ set node status to BREAKED and update to storage
191
+ '''
192
+ node = self._node_exist(node_id)
193
+ node.status = RunNodeStatus.BREAKED
194
+ self.storage.update(node)
195
+
196
+ def run_succeed(self,
197
+ node_id,
198
+ result_msg=None,
199
+ results: List[HandleResult] = None):
200
+ '''
201
+ set node status to SUCCESS and update to storage
202
+ '''
203
+ node = self._node_exist(node_id)
204
+ node.status = RunNodeStatus.SUCCESS
205
+ node.result_msg = result_msg
206
+ node.end_time = time.time()
207
+ if results:
208
+ if not node.results:
209
+ node.results = []
210
+ node.results.extend(results)
211
+ self.storage.update(node)
212
+
213
+ def run_failed(self,
214
+ node_id,
215
+ result_msg=None,
216
+ results: List[HandleResult] = None):
217
+ '''
218
+ set node status to FAILED and update to storage
219
+ '''
220
+ node = self._node_exist(node_id)
221
+ node.status = RunNodeStatus.FAILED
222
+ node.result_msg = result_msg
223
+ node.end_time = time.time()
224
+ if results:
225
+ if not node.results:
226
+ node.results = []
227
+ node.results.extend(results)
228
+ self.storage.update(node)
229
+
230
+ def run_timeout(self,
231
+ node_id,
232
+ result_msg=None):
233
+ '''
234
+ set node status to TIMEOUNT and update to storage
235
+ '''
236
+ node = self._node_exist(node_id)
237
+ node.status = RunNodeStatus.TIMEOUNT
238
+ node.result_msg = result_msg
239
+ self.storage.update(node)
240
+
241
+ def get_node(self, node_id: str) -> RunNode:
242
+ '''
243
+ get node from storage
244
+ '''
245
+ return self._find_node(node_id)
246
+
247
+ def get_nodes(self, session_id: str) -> List[RunNode]:
248
+ '''
249
+ get nodes from storage
250
+ '''
251
+ return self.storage.query(session_id)
252
+
253
+ def _node_exist(self, node_id: str):
254
+ node = self._find_node(node_id)
255
+ if not node:
256
+ raise Exception(f"node not found, node_id: {node_id}")
257
+ return node
258
+
259
+ def _find_node(self, node_id: str):
260
+ return self.storage.get(node_id)
261
+
262
+ def _judge_msg_from_busi_type(self, msg_from: str) -> RunNodeBusiType:
263
+ '''
264
+ judge msg_from busi_type
265
+ '''
266
+ if is_agent_by_name(msg_from):
267
+ return RunNodeBusiType.AGENT
268
+ if is_tool_by_name(msg_from):
269
+ return RunNodeBusiType.TOOL
270
+ return RunNodeBusiType.TASK
271
+
272
+
273
+ class EventRuntimeStateManager(RuntimeStateManager):
274
+
275
+ def __init__(self, storage: StateStorage = InMemoryStateStorage.instance()):
276
+ super().__init__(storage)
277
+
278
+ def start_message_node(self, message: Message):
279
+ '''
280
+ create and start node while message handle started.
281
+ '''
282
+ run_node_busi_type = RunNodeBusiType.from_message_category(
283
+ message.category)
284
+ logger.info(
285
+ f"start message node: {message.receiver}, busi_type={run_node_busi_type}, node_id={message.id}")
286
+ if run_node_busi_type:
287
+ self.create_node(
288
+ node_id=message.id,
289
+ busi_type=run_node_busi_type,
290
+ busi_id=message.receiver,
291
+ session_id=message.session_id,
292
+ msg_id=message.id,
293
+ msg_from=message.sender)
294
+ self.run_node(message.id)
295
+
296
+ def save_message_handle_result(self, name: str, message: Message, result: Message = None):
297
+ '''
298
+ save message handle result
299
+ '''
300
+ run_node_busi_type = RunNodeBusiType.from_message_category(
301
+ message.category)
302
+ if run_node_busi_type:
303
+ if result and result.is_error():
304
+ handle_result = HandleResult(
305
+ name=name,
306
+ status=RunNodeStatus.FAILED,
307
+ result=result)
308
+ else:
309
+ handle_result = HandleResult(
310
+ name=name,
311
+ status=RunNodeStatus.SUCCESS,
312
+ result=result)
313
+ self.save_result(node_id=message.id, result=handle_result)
314
+
315
+ def end_message_node(self, message: Message):
316
+ '''
317
+ end node while message handle finished.
318
+ '''
319
+ run_node_busi_type = RunNodeBusiType.from_message_category(
320
+ message.category)
321
+ if run_node_busi_type:
322
+ node = self._node_exist(node_id=message.id)
323
+ status = RunNodeStatus.SUCCESS
324
+ if node.results:
325
+ for result in node.results:
326
+ if result.status == RunNodeStatus.FAILED:
327
+ status = RunNodeStatus.FAILED
328
+ break
329
+ if status == RunNodeStatus.FAILED:
330
+ self.run_failed(node_id=message.id)
331
+ else:
332
+ self.run_succeed(node_id=message.id)
aworld/runners/task_runner.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ import time
5
+ import uuid
6
+ from typing import Callable, Any
7
+
8
+ from pydantic import BaseModel
9
+
10
+ import aworld.tools
11
+ from aworld.config import ConfigDict
12
+ from aworld.config.conf import ToolConfig
13
+ from aworld.core.agent.swarm import Swarm
14
+ from aworld.core.common import Observation
15
+ from aworld.core.context.base import Context
16
+ from aworld.core.context.session import Session
17
+ from aworld.core.tool.base import Tool, AsyncTool
18
+ from aworld.core.task import Task, TaskResponse, Runner
19
+ from aworld.logs.util import logger
20
+ from aworld import trace
21
+
22
+
23
+ class TaskRunner(Runner):
24
+ """Task based runner api class."""
25
+ __metaclass__ = abc.ABCMeta
26
+
27
+ def __init__(self,
28
+ task: Task,
29
+ *,
30
+ agent_oriented: bool = True,
31
+ daemon_target: Callable[..., Any] = None):
32
+ """Task runner initialize.
33
+
34
+ Args:
35
+ task: Task entity to be executed.
36
+ agent_oriented: Is it an agent oriented task, default is True.
37
+ """
38
+ if task.tools is None:
39
+ task.tools = []
40
+ if task.tool_names is None:
41
+ task.tool_names = []
42
+
43
+ if agent_oriented:
44
+ if not task.agent and not task.swarm:
45
+ raise ValueError("agent and swarm all is None.")
46
+ if task.agent and task.swarm:
47
+ raise ValueError("agent and swarm choose one only.")
48
+ if task.agent:
49
+ # uniform agent
50
+ task.swarm = Swarm(task.agent)
51
+
52
+ if task.conf is None:
53
+ task.conf = dict()
54
+ if isinstance(task.conf, BaseModel):
55
+ task.conf = task.conf.model_dump()
56
+ check_input = task.conf.get("check_input", False)
57
+ if check_input and not task.input:
58
+ raise ValueError("task no input")
59
+
60
+ self.context = task.context if task.context else Context.instance()
61
+ self.task = task
62
+ self.context.set_task(task)
63
+ self.agent_oriented = agent_oriented
64
+ self.daemon_target = daemon_target
65
+ self._use_demon = False if not task.conf else task.conf.get('use_demon', False)
66
+ self._exception = None
67
+ self.start_time = time.time()
68
+ self.step_agent_counter = {}
69
+
70
+ async def pre_run(self):
71
+ task = self.task
72
+ self.swarm = task.swarm
73
+ self.input = task.input
74
+ self.outputs = task.outputs
75
+ self.name = task.name
76
+ self.conf = task.conf if task.conf else ConfigDict()
77
+ self.tools = {tool.name(): tool for tool in task.tools} if task.tools else {}
78
+ task.tool_names.extend(self.tools.keys())
79
+ # lazy load
80
+ self.tool_names = task.tool_names
81
+ self.tools_conf = task.tools_conf
82
+ if self.tools_conf is None:
83
+ self.tools_conf = {}
84
+ # mcp performs special process, use async only in the runn
85
+ self.tools_conf['mcp'] = ToolConfig(use_async=True, name='mcp')
86
+ self.endless_threshold = task.endless_threshold
87
+
88
+ # build context
89
+ if task.session_id:
90
+ session = Session(session_id=task.session_id)
91
+ else:
92
+ session = Session(session_id=uuid.uuid1().hex)
93
+ trace_id = uuid.uuid1().hex if trace.get_current_span() is None else trace.get_current_span().get_trace_id()
94
+ self.context.task_id = self.name
95
+ self.context.trace_id = trace_id
96
+ self.context.session = session
97
+ self.context.swarm = self.swarm
98
+
99
+ # init tool state by reset(), and ignore them observation
100
+ observation = None
101
+ if self.tools:
102
+ for _, tool in self.tools.items():
103
+ # use the observation and info of the last one
104
+ if isinstance(tool, Tool):
105
+ tool.context = self.context
106
+ observation, info = tool.reset()
107
+ elif isinstance(tool, AsyncTool):
108
+ observation, info = await tool.reset()
109
+ else:
110
+ logger.warning(f"Unsupported tool type: {tool}, will ignored.")
111
+
112
+ if observation:
113
+ if not observation.content:
114
+ observation.content = self.input
115
+ else:
116
+ observation = Observation(content=self.input)
117
+
118
+ self.observation = observation
119
+ if self.swarm:
120
+ self.swarm.event_driven = task.event_driven
121
+ self.swarm.reset(observation.content, context=self.context, tools=self.tool_names)
122
+
123
+ async def post_run(self):
124
+ self.context.reset()
125
+
126
+ @abc.abstractmethod
127
+ async def do_run(self, context: Context = None) -> TaskResponse:
128
+ """Task do run."""
aworld/runners/utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ from typing import List, Dict
4
+
5
+ from aworld.config import RunConfig
6
+ from aworld.core.agent.swarm import GraphBuildType
7
+ from aworld.core.common import Config
8
+
9
+ from aworld.core.task import Task, TaskResponse, Runner
10
+ from aworld.logs.util import logger
11
+ from aworld.utils.common import new_instance, snake_to_camel
12
+
13
+
14
+ async def choose_runners(tasks: List[Task]) -> List[Runner]:
15
+ """Choose the correct runner to run the task.
16
+
17
+ Args:
18
+ task: A task that contains agents, tools and datas.
19
+
20
+ Returns:
21
+ Runner instance or exception.
22
+ """
23
+ runners = []
24
+ for task in tasks:
25
+ # user custom runner class
26
+ runner_cls = task.runner_cls
27
+ if runner_cls:
28
+ return new_instance(runner_cls, task)
29
+ else:
30
+ # user runner class in the framework
31
+ if task.swarm:
32
+ task.swarm.event_driven = task.event_driven
33
+ execute_type = task.swarm.build_type
34
+ else:
35
+ execute_type = GraphBuildType.WORKFLOW.value
36
+
37
+ if task.event_driven:
38
+ runner = new_instance("aworld.runners.event_runner.TaskEventRunner", task)
39
+ else:
40
+ runner = new_instance(
41
+ f"aworld.runners.call_driven_runner.{snake_to_camel(execute_type)}Runner",
42
+ task
43
+ )
44
+ runners.append(runner)
45
+ return runners
46
+
47
+
48
+ async def execute_runner(runners: List[Runner], run_conf: RunConfig) -> Dict[str, TaskResponse]:
49
+ """Execute runner in the runtime engine.
50
+
51
+ Args:
52
+ runners: The task processing flow.
53
+ run_conf: Runtime config, can choose the special computing engine to execute the runner.
54
+ """
55
+ if not run_conf:
56
+ run_conf = RunConfig()
57
+
58
+ name = run_conf.name
59
+ if run_conf.cls:
60
+ runtime_backend = new_instance(run_conf.cls, run_conf)
61
+ else:
62
+ runtime_backend = new_instance(
63
+ f"aworld.core.runtime_engine.{snake_to_camel(name)}Runtime", run_conf)
64
+ runtime_engine = runtime_backend.build_engine()
65
+ return await runtime_engine.execute([runner.run for runner in runners])
66
+
67
+
68
+ def endless_detect(records: List[str], endless_threshold: int, root_agent_name: str):
69
+ """A very simple implementation of endless loop detection.
70
+
71
+ Args:
72
+ records: Call sequence of agent.
73
+ endless_threshold: Threshold for the number of repetitions.
74
+ root_agent_name: Name of the entrance agent.
75
+ """
76
+ if not records:
77
+ return False
78
+
79
+ threshold = endless_threshold
80
+ last_agent_name = root_agent_name
81
+ count = 1
82
+ for i in range(len(records) - 2, -1, -1):
83
+ if last_agent_name == records[i]:
84
+ count += 1
85
+ else:
86
+ last_agent_name = records[i]
87
+ count = 1
88
+
89
+ if count >= threshold:
90
+ logger.warning("detect loop, will exit the loop.")
91
+ return True
92
+
93
+ if len(records) > 6:
94
+ last_agent_name = None
95
+ # latest
96
+ for j in range(1, 3):
97
+ for i in range(len(records) - j, 0, -2):
98
+ if last_agent_name and last_agent_name == (records[i], records[i - 1]):
99
+ count += 1
100
+ elif last_agent_name is None:
101
+ last_agent_name = (records[i], records[i - 1])
102
+ count = 1
103
+ else:
104
+ last_agent_name = None
105
+ break
106
+
107
+ if count >= threshold:
108
+ logger.warning(f"detect loop: {last_agent_name}, will exit the loop.")
109
+ return True
110
+
111
+ return False