Duibonduil commited on
Commit
bc5e560
·
verified ·
1 Parent(s): 8f1608c

Upload 5 files

Browse files
aworld/runners/handler/agent.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ from typing import AsyncGenerator, Tuple
5
+
6
+ from aworld.agents.loop_llm_agent import LoopableAgent
7
+ from aworld.core.agent.base import is_agent, AgentFactory
8
+ from aworld.core.agent.swarm import GraphBuildType
9
+ from aworld.core.common import ActionModel, Observation, TaskItem
10
+ from aworld.core.event.base import Message, Constants, TopicType
11
+ from aworld.logs.util import logger
12
+ from aworld.runners.handler.base import DefaultHandler
13
+ from aworld.runners.handler.tool import DefaultToolHandler
14
+ from aworld.runners.utils import endless_detect
15
+ from aworld.output.base import StepOutput
16
+
17
+
18
+ class AgentHandler(DefaultHandler):
19
+ __metaclass__ = abc.ABCMeta
20
+
21
+ def __init__(self, runner: 'TaskEventRunner'):
22
+ self.swarm = runner.swarm
23
+ self.endless_threshold = runner.endless_threshold
24
+
25
+ self.agent_calls = []
26
+
27
+ @classmethod
28
+ def name(cls):
29
+ return "_agents_handler"
30
+
31
+
32
+ class DefaultAgentHandler(AgentHandler):
33
+ async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
34
+ if message.category != Constants.AGENT:
35
+ if message.sender in self.swarm.agents and message.sender in AgentFactory:
36
+ if self.agent_calls:
37
+ if self.agent_calls[-1] != message.sender:
38
+ self.agent_calls.append(message.sender)
39
+ else:
40
+ self.agent_calls.append(message.sender)
41
+ return
42
+
43
+ headers = {"context": message.context}
44
+ session_id = message.session_id
45
+ data = message.payload
46
+ if not data:
47
+ # error message, p2p
48
+ yield Message(
49
+ category=Constants.OUTPUT,
50
+ payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
51
+ step_num=0,
52
+ data="no data to process."),
53
+ sender=self.name(),
54
+ session_id=session_id,
55
+ headers=headers
56
+ )
57
+ yield Message(
58
+ category=Constants.TASK,
59
+ payload=TaskItem(msg="no data to process.", data=data, stop=True),
60
+ sender=self.name(),
61
+ session_id=session_id,
62
+ topic=TopicType.ERROR,
63
+ headers=headers
64
+ )
65
+ return
66
+
67
+ if isinstance(data, Tuple) and isinstance(data[0], Observation):
68
+ data = data[0]
69
+ message.payload = data
70
+ # data is Observation
71
+ if isinstance(data, Observation):
72
+ if not self.swarm:
73
+ msg = Message(
74
+ category=Constants.TASK,
75
+ payload=data.content,
76
+ sender=data.observer,
77
+ session_id=session_id,
78
+ topic=TopicType.FINISHED,
79
+ headers=headers
80
+ )
81
+ logger.info(f"agent handler send finished message: {msg}")
82
+ yield msg
83
+ return
84
+
85
+ agent = self.swarm.agents.get(message.receiver)
86
+ # agent + tool completion protocol.
87
+ if agent and agent.finished and data.info.get('done'):
88
+ self.swarm.cur_step += 1
89
+ if agent.id() == self.swarm.communicate_agent.id():
90
+ msg = Message(
91
+ category=Constants.TASK,
92
+ payload=data.content,
93
+ sender=agent.id(),
94
+ session_id=session_id,
95
+ topic=TopicType.FINISHED,
96
+ headers=headers
97
+ )
98
+ logger.info(f"agent handler send finished message: {msg}")
99
+ yield msg
100
+ else:
101
+ msg = Message(
102
+ category=Constants.AGENT,
103
+ payload=Observation(content=data.content),
104
+ sender=agent.id(),
105
+ session_id=session_id,
106
+ receiver=self.swarm.communicate_agent.id(),
107
+ headers=headers
108
+ )
109
+ logger.info(f"agent handler send agent message: {msg}")
110
+ yield msg
111
+ else:
112
+ if data.info.get('done'):
113
+ agent_name = self.agent_calls[-1]
114
+ async for event in self._stop_check(ActionModel(agent_name=agent_name, policy_info=data.content),
115
+ message):
116
+ yield event
117
+ return
118
+ logger.info(f"agent handler send observation message: {message}")
119
+ yield message
120
+ return
121
+
122
+ # data is List[ActionModel]
123
+ for action in data:
124
+ if not isinstance(action, ActionModel):
125
+ # error message, p2p
126
+ yield Message(
127
+ category=Constants.OUTPUT,
128
+ payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
129
+ step_num=0,
130
+ data="action not a ActionModel."),
131
+ sender=self.name(),
132
+ session_id=session_id,
133
+ headers=headers
134
+ )
135
+ msg = Message(
136
+ category=Constants.TASK,
137
+ payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
138
+ sender=self.name(),
139
+ session_id=session_id,
140
+ topic=TopicType.ERROR,
141
+ headers=headers
142
+ )
143
+ logger.info(f"agent handler send task message: {msg}")
144
+ yield msg
145
+ return
146
+
147
+ tools = []
148
+ agents = []
149
+ for action in data:
150
+ if is_agent(action):
151
+ agents.append(action)
152
+ else:
153
+ tools.append(action)
154
+
155
+ if tools:
156
+ msg = Message(
157
+ category=Constants.TOOL,
158
+ payload=tools,
159
+ sender=self.name(),
160
+ session_id=session_id,
161
+ receiver=DefaultToolHandler.name(),
162
+ headers=headers
163
+ )
164
+ logger.info(f"agent handler send tool message: {msg}")
165
+ yield msg
166
+ else:
167
+ yield Message(
168
+ category=Constants.OUTPUT,
169
+ payload=StepOutput.build_finished_output(name=f"{message.caller or self.name()}",
170
+ step_num=0),
171
+ sender=self.name(),
172
+ receiver=agents[0].tool_name,
173
+ session_id=session_id,
174
+ headers=headers
175
+ )
176
+
177
+ for agent in agents:
178
+ async for event in self._agent(agent, message):
179
+ logger.info(f"agent handler send message: {event}")
180
+ yield event
181
+
182
+ async def _agent(self, action: ActionModel, message: Message):
183
+ self.agent_calls.append(action.agent_name)
184
+ agent = self.swarm.agents.get(action.agent_name)
185
+ # be handoff
186
+ agent_name = action.tool_name
187
+ if not agent_name:
188
+ async for event in self._stop_check(action, message):
189
+ yield event
190
+ return
191
+
192
+ headers = {"context": message.context}
193
+ session_id = message.session_id
194
+ cur_agent = self.swarm.agents.get(agent_name)
195
+ if not cur_agent or not agent:
196
+ yield Message(
197
+ category=Constants.TASK,
198
+ payload=TaskItem(msg=f"Can not find {agent_name} or {action.agent_name} agent in swarm.",
199
+ data=action,
200
+ stop=True),
201
+ sender=self.name(),
202
+ session_id=session_id,
203
+ topic=TopicType.ERROR,
204
+ headers=headers
205
+ )
206
+ return
207
+
208
+ cur_agent._finished = False
209
+ con = action.policy_info
210
+ if action.params and 'content' in action.params:
211
+ con = action.params['content']
212
+ observation = Observation(content=con, observer=agent.id(), from_agent_name=agent.id())
213
+
214
+ if agent.handoffs and agent_name not in agent.handoffs:
215
+ if message.caller:
216
+ message.receiver = message.caller
217
+ message.caller = ''
218
+ yield message
219
+ else:
220
+ yield Message(category=Constants.TASK,
221
+ payload=TaskItem(msg=f"Can not handoffs {agent_name} agent ", data=observation),
222
+ sender=self.name(),
223
+ session_id=session_id,
224
+ topic=TopicType.RERUN,
225
+ headers=headers)
226
+ return
227
+
228
+ yield Message(
229
+ category=Constants.AGENT,
230
+ payload=observation,
231
+ caller=message.caller,
232
+ sender=action.agent_name,
233
+ session_id=session_id,
234
+ receiver=action.tool_name,
235
+ headers=headers
236
+ )
237
+
238
+ async def _stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
239
+ if GraphBuildType.WORKFLOW.value != self.swarm.build_type:
240
+ async for event in self._social_stop_check(action, message):
241
+ yield event
242
+ else:
243
+ if self.swarm.has_cycle:
244
+ async for event in self._loop_sequence_stop_check(action, message):
245
+ yield event
246
+ else:
247
+ async for event in self._sequence_stop_check(action, message):
248
+ yield event
249
+
250
+ async def _sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
251
+ headers = {"context": message.context}
252
+ session_id = message.session_id
253
+ agent = self.swarm.agents.get(action.agent_name)
254
+ ordered_agents = self.swarm.ordered_agents
255
+ idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1)
256
+ if idx == -1:
257
+ yield Message(
258
+ category=Constants.TASK,
259
+ payload=action,
260
+ sender=self.name(),
261
+ session_id=session_id,
262
+ topic=TopicType.ERROR,
263
+ headers=headers
264
+ )
265
+ return
266
+
267
+ # The last agent
268
+ if idx == len(self.swarm.ordered_agents) - 1:
269
+ receiver = None
270
+ # agent loop
271
+ if isinstance(agent, LoopableAgent):
272
+ agent.cur_run_times += 1
273
+ if not agent.finished:
274
+ receiver = agent.goto
275
+
276
+ if receiver:
277
+ yield Message(
278
+ category=Constants.AGENT,
279
+ payload=Observation(content=action.policy_info),
280
+ sender=agent.id(),
281
+ session_id=session_id,
282
+ receiver=receiver,
283
+ headers=headers
284
+ )
285
+ else:
286
+ logger.info(f"execute loop {self.swarm.cur_step}.")
287
+ yield Message(
288
+ category=Constants.TASK,
289
+ payload=action.policy_info,
290
+ sender=agent.id(),
291
+ session_id=session_id,
292
+ topic=TopicType.FINISHED,
293
+ headers=headers
294
+ )
295
+ return
296
+
297
+ # loop agent type
298
+ if isinstance(agent, LoopableAgent):
299
+ agent.cur_run_times += 1
300
+ if agent.finished:
301
+ receiver = self.swarm.ordered_agents[idx + 1].id()
302
+ else:
303
+ receiver = agent.goto
304
+ else:
305
+ # means the loop finished
306
+ receiver = self.swarm.ordered_agents[idx + 1].id()
307
+ yield Message(
308
+ category=Constants.AGENT,
309
+ payload=Observation(content=action.policy_info),
310
+ sender=agent.id(),
311
+ session_id=session_id,
312
+ receiver=receiver,
313
+ headers=headers
314
+ )
315
+
316
+ async def _loop_sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
317
+ headers = {"context": message.context}
318
+ session_id = message.session_id
319
+ agent = self.swarm.agents.get(action.agent_name)
320
+ idx = next((i for i, x in enumerate(self.swarm.ordered_agents) if x == agent), -1)
321
+ if idx == -1:
322
+ # unknown agent, means something wrong
323
+ yield Message(
324
+ category=Constants.TASK,
325
+ payload=action,
326
+ sender=self.name(),
327
+ session_id=session_id,
328
+ topic=TopicType.ERROR,
329
+ headers=headers
330
+ )
331
+ return
332
+ if idx == len(self.swarm.ordered_agents) - 1:
333
+ # supported sequence loop
334
+ if self.swarm.cur_step >= self.swarm.max_steps:
335
+ receiver = None
336
+ # agent loop
337
+ if isinstance(agent, LoopableAgent):
338
+ agent.cur_run_times += 1
339
+ if not agent.finished:
340
+ receiver = agent.goto
341
+
342
+ if receiver:
343
+ yield Message(
344
+ category=Constants.AGENT,
345
+ payload=Observation(content=action.policy_info),
346
+ sender=agent.id(),
347
+ session_id=session_id,
348
+ receiver=receiver,
349
+ headers=headers
350
+ )
351
+ else:
352
+ # means the task finished
353
+ yield Message(
354
+ category=Constants.TASK,
355
+ payload=action.policy_info,
356
+ sender=agent.id(),
357
+ session_id=session_id,
358
+ topic=TopicType.FINISHED,
359
+ headers=headers
360
+ )
361
+ else:
362
+ self.swarm.cur_step += 1
363
+ logger.info(f"execute loop {self.swarm.cur_step}.")
364
+ yield Message(
365
+ category=Constants.TASK,
366
+ payload='',
367
+ sender=agent.id(),
368
+ session_id=session_id,
369
+ topic=TopicType.START,
370
+ headers=headers
371
+ )
372
+ return
373
+
374
+ if isinstance(agent, LoopableAgent):
375
+ agent.cur_run_times += 1
376
+ if agent.finished:
377
+ receiver = self.swarm.ordered_agents[idx + 1].id()
378
+ else:
379
+ receiver = agent.goto
380
+ else:
381
+ # means the loop finished
382
+ receiver = self.swarm.ordered_agents[idx + 1].id()
383
+ yield Message(
384
+ category=Constants.AGENT,
385
+ payload=Observation(content=action.policy_info),
386
+ sender=agent.name(),
387
+ session_id=session_id,
388
+ receiver=receiver,
389
+ headers=headers
390
+ )
391
+
392
+ async def _social_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
393
+ headers = {"context": message.context}
394
+ agent = self.swarm.agents.get(action.agent_name)
395
+ caller = message.caller
396
+ session_id = message.session_id
397
+ if endless_detect(self.agent_calls,
398
+ endless_threshold=self.endless_threshold,
399
+ root_agent_name=self.swarm.communicate_agent.id()):
400
+ yield Message(
401
+ category=Constants.TASK,
402
+ payload=action.policy_info,
403
+ sender=agent.id(),
404
+ session_id=session_id,
405
+ topic=TopicType.FINISHED,
406
+ headers=headers
407
+ )
408
+ return
409
+
410
+ if not caller or caller == self.swarm.communicate_agent.id():
411
+ if self.swarm.cur_step >= self.swarm.max_steps or self.swarm.finished:
412
+ yield Message(
413
+ category=Constants.TASK,
414
+ payload=action.policy_info,
415
+ sender=agent.id(),
416
+ session_id=session_id,
417
+ topic=TopicType.FINISHED,
418
+ headers=headers
419
+ )
420
+ else:
421
+ self.swarm.cur_step += 1
422
+ logger.info(f"execute loop {self.swarm.cur_step}.")
423
+ yield Message(
424
+ category=Constants.AGENT,
425
+ payload=Observation(content=action.policy_info),
426
+ sender=agent.id(),
427
+ session_id=session_id,
428
+ receiver=self.swarm.communicate_agent.id(),
429
+ headers=headers
430
+ )
431
+ else:
432
+ idx = 0
433
+ for idx, name in enumerate(self.agent_calls[::-1]):
434
+ if name == agent.id():
435
+ break
436
+ idx = len(self.agent_calls) - idx - 1
437
+ if idx:
438
+ caller = self.agent_calls[idx - 1]
439
+
440
+ yield Message(
441
+ category=Constants.AGENT,
442
+ payload=Observation(content=action.policy_info),
443
+ sender=agent.id(),
444
+ session_id=session_id,
445
+ receiver=caller,
446
+ headers=headers
447
+ )
aworld/runners/handler/base.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+
5
+ from typing import TypeVar, Generic, AsyncGenerator
6
+
7
+ from aworld.core.event.base import Message
8
+
9
+ IN = TypeVar('IN')
10
+ OUT = TypeVar('OUT')
11
+
12
+
13
+ class Handler(Generic[IN, OUT]):
14
+ __metaclass__ = abc.ABCMeta
15
+
16
+ @abc.abstractmethod
17
+ async def handle(self, data: IN) -> AsyncGenerator[OUT, None]:
18
+ """Process the data as the expected result.
19
+
20
+ Args:
21
+ data: Data generated while running the task.
22
+ """
23
+
24
+ @classmethod
25
+ def name(cls):
26
+ """Handler name."""
27
+ return cls.__name__
28
+
29
+
30
+ class DefaultHandler(Handler[Message, AsyncGenerator[Message, None]]):
31
+ """Default handler."""
aworld/runners/handler/output.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # aworld/runners/handler/output.py
2
+ import json
3
+ from typing import AsyncGenerator
4
+ from aworld.core.task import TaskResponse
5
+ from aworld.models.model_response import ModelResponse
6
+ from aworld.runners.handler.base import DefaultHandler
7
+ from aworld.output.base import StepOutput, MessageOutput, ToolResultOutput, Output
8
+ from aworld.core.common import TaskItem
9
+ from aworld.core.context.base import Context
10
+ from aworld.core.event.base import Message, Constants, TopicType
11
+ from aworld.logs.util import logger
12
+
13
+
14
+ class DefaultOutputHandler(DefaultHandler):
15
+ def __init__(self, runner):
16
+ self.runner = runner
17
+
18
+ async def handle(self, message):
19
+ if message.category != Constants.OUTPUT:
20
+ return
21
+ # 1. get outputs
22
+ outputs = self.runner.task.outputs
23
+ if not outputs:
24
+ yield Message(
25
+ category=Constants.TASK,
26
+ payload=TaskItem(msg="Cannot get outputs.", data=message, stop=True),
27
+ sender=self.name(),
28
+ session_id=Context.instance().session_id,
29
+ topic=TopicType.ERROR,
30
+ headers={"context": message.context}
31
+ )
32
+ return
33
+ # 2. build Output
34
+ payload = message.payload
35
+ mark_complete = False
36
+ output = None
37
+ try:
38
+ if isinstance(payload, Output):
39
+ output = payload
40
+ elif isinstance(payload, TaskResponse):
41
+ logger.info(f"output get task_response with usage: {json.dumps(payload.usage)}")
42
+ if message.topic == TopicType.FINISHED or message.topic == TopicType.ERROR:
43
+ mark_complete = True
44
+ elif isinstance(payload, ModelResponse) or isinstance(payload, AsyncGenerator):
45
+ output = MessageOutput(source=payload)
46
+ except Exception as e:
47
+ logger.warning(f"Failed to parse output: {e}")
48
+ yield Message(
49
+ category=Constants.TASK,
50
+ payload=TaskItem(msg="Failed to parse output.", data=payload, stop=True),
51
+ sender=self.name(),
52
+ session_id=Context.instance().session_id,
53
+ topic=TopicType.ERROR,
54
+ headers={"context": message.context}
55
+ )
56
+ finally:
57
+ if output:
58
+ if not output.metadata:
59
+ output.metadata = {}
60
+ output.metadata['sender'] = message.sender
61
+ output.metadata['receiver'] = message.receiver
62
+ await outputs.add_output(output)
63
+ if mark_complete:
64
+ await outputs.mark_completed()
65
+
66
+ return
aworld/runners/handler/task.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ import time
5
+
6
+ from typing import AsyncGenerator
7
+
8
+ from aworld.core.common import TaskItem
9
+ from aworld.core.tool.base import Tool, AsyncTool
10
+
11
+ from aworld.core.event.base import Message, Constants, TopicType
12
+ from aworld.core.task import TaskResponse
13
+ from aworld.logs.util import logger
14
+ from aworld.output import Output
15
+ from aworld.runners.handler.base import DefaultHandler
16
+ from aworld.runners.hook.hook_factory import HookFactory
17
+ from aworld.runners.hook.hooks import HookPoint
18
+
19
+
20
+ class TaskHandler(DefaultHandler):
21
+ __metaclass__ = abc.ABCMeta
22
+
23
+ def __init__(self, runner: 'TaskEventRunner'):
24
+ self.runner = runner
25
+ self.retry_count = 0
26
+ self.hooks = {}
27
+ if runner.task.hooks:
28
+ for k, vals in runner.task.hooks.items():
29
+ self.hooks[k] = []
30
+ for v in vals:
31
+ cls = HookFactory.get_class(v)
32
+ if cls:
33
+ self.hooks[k].append(cls)
34
+
35
+ @classmethod
36
+ def name(cls):
37
+ return "_task_handler"
38
+
39
+
40
+ class DefaultTaskHandler(TaskHandler):
41
+ async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
42
+ if message.category != Constants.TASK:
43
+ return
44
+
45
+ logger.info(f"task handler receive message: {message}")
46
+
47
+ headers = {"context": message.context}
48
+ topic = message.topic
49
+ task_item: TaskItem = message.payload
50
+ if topic == TopicType.SUBSCRIBE_TOOL:
51
+ new_tools = message.payload.data
52
+ for name, tool in new_tools.items():
53
+ if isinstance(tool, Tool) or isinstance(tool, AsyncTool):
54
+ await self.runner.event_mng.register(Constants.TOOL, name, tool.step)
55
+ logger.info(f"dynamic register {name} tool.")
56
+ else:
57
+ logger.warning(f"Unknown tool instance: {tool}")
58
+ return
59
+ elif topic == TopicType.SUBSCRIBE_AGENT:
60
+ return
61
+ elif topic == TopicType.ERROR:
62
+ async for event in self.run_hooks(message, HookPoint.ERROR):
63
+ yield event
64
+
65
+ if task_item.stop:
66
+ await self.runner.stop()
67
+ logger.warning(f"task {self.runner.task.id} stop, cause: {task_item.msg}")
68
+ self.runner._task_response = TaskResponse(msg=task_item.msg,
69
+ answer='',
70
+ success=False,
71
+ id=self.runner.task.id,
72
+ time_cost=(time.time() - self.runner.start_time),
73
+ usage=self.runner.context.token_usage)
74
+ return
75
+ # restart
76
+ logger.warning(f"The task {self.runner.task.id} will be restarted due to error: {task_item.msg}.")
77
+ if self.retry_count >= 3:
78
+ raise Exception(f"The task {self.runner.task.id} failed, due to error: {task_item.msg}.")
79
+
80
+ self.retry_count += 1
81
+ yield Message(
82
+ category=Constants.TASK,
83
+ payload='',
84
+ sender=self.name(),
85
+ session_id=self.runner.context.session_id,
86
+ topic=TopicType.START,
87
+ headers=headers
88
+ )
89
+ elif topic == TopicType.FINISHED:
90
+ async for event in self.run_hooks(message, HookPoint.FINISHED):
91
+ yield event
92
+
93
+ self.runner._task_response = TaskResponse(answer=str(message.payload),
94
+ success=True,
95
+ id=self.runner.task.id,
96
+ time_cost=(time.time() - self.runner.start_time),
97
+ usage=self.runner.context.token_usage)
98
+ await self.runner.stop()
99
+
100
+ logger.info(f"{self.runner.task.id} finished.")
101
+ elif topic == TopicType.START:
102
+ async for event in self.run_hooks(message, HookPoint.START):
103
+ yield event
104
+
105
+ logger.info(f"task start event: {message}, will send init message.")
106
+ if message.payload:
107
+ yield message
108
+ else:
109
+ yield self.runner.init_message
110
+ elif topic == TopicType.OUTPUT:
111
+ yield message
112
+ elif topic == TopicType.HUMAN_CONFIRM:
113
+ logger.warn("=============== Get human confirm, pause execution ===============")
114
+ if self.runner.task.outputs and message.payload:
115
+ await self.runner.task.outputs.add_output(Output(data=message.payload))
116
+ self.runner._task_response = TaskResponse(answer=str(message.payload),
117
+ success=True,
118
+ id=self.runner.task.id,
119
+ time_cost=(time.time() - self.runner.start_time),
120
+ usage=self.runner.context.token_usage)
121
+ await self.runner.stop()
122
+
123
+ async def run_hooks(self, message: Message, hook_point: str) -> AsyncGenerator[Message, None]:
124
+ hooks = self.hooks.get(hook_point, [])
125
+ for hook in hooks:
126
+ try:
127
+ msg = hook(message)
128
+ if msg:
129
+ yield msg
130
+ except:
131
+ logger.warning(f"{hook.point()} {hook.name()} execute fail.")
aworld/runners/handler/tool.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ from typing import AsyncGenerator
5
+
6
+ from aworld.core.agent.base import is_agent
7
+ from aworld.core.common import ActionModel, TaskItem
8
+ from aworld.core.event.base import Message, Constants, TopicType
9
+ from aworld.core.tool.base import AsyncTool, Tool, ToolFactory
10
+ from aworld.logs.util import logger
11
+ from aworld.runners.handler.base import DefaultHandler
12
+
13
+
14
+ class ToolHandler(DefaultHandler):
15
+ __metaclass__ = abc.ABCMeta
16
+
17
+ def __init__(self, runner: 'TaskEventRunner'):
18
+ self.tools = runner.tools
19
+ self.tools_conf = runner.tools_conf
20
+
21
+ @classmethod
22
+ def name(cls):
23
+ return "_tool_handler"
24
+
25
+
26
+ class DefaultToolHandler(ToolHandler):
27
+ async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
28
+ if message.category != Constants.TOOL:
29
+ return
30
+
31
+ headers = {"context": message.context}
32
+ # data is List[ActionModel]
33
+ data = message.payload
34
+ if not data:
35
+ # error message, p2p
36
+ yield Message(
37
+ category=Constants.TASK,
38
+ payload=TaskItem(msg="no data to process.", data=data, stop=True),
39
+ sender='agent_handler',
40
+ session_id=message.session_id,
41
+ topic=TopicType.ERROR,
42
+ headers=headers
43
+ )
44
+ return
45
+
46
+ for action in data:
47
+ if not isinstance(action, ActionModel):
48
+ # error message, p2p
49
+ yield Message(
50
+ category=Constants.TASK,
51
+ payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
52
+ sender=self.name(),
53
+ session_id=message.session_id,
54
+ topic=TopicType.ERROR,
55
+ headers=headers
56
+ )
57
+ return
58
+
59
+ new_tools = dict()
60
+ tool_mapping = dict()
61
+ # Directly use or use tools after creation.
62
+ for act in data:
63
+ if is_agent(act):
64
+ logger.warning(f"somethings wrong, {act} is an agent.")
65
+ continue
66
+
67
+ if not self.tools or (self.tools and act.tool_name not in self.tools):
68
+ # dynamic only use default config in module.
69
+ conf = self.tools_conf.get(act.tool_name)
70
+ tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False)
71
+ tool.event_driven = True
72
+ if isinstance(tool, Tool):
73
+ tool.reset()
74
+ elif isinstance(tool, AsyncTool):
75
+ await tool.reset()
76
+ tool_mapping[act.tool_name] = []
77
+ self.tools[act.tool_name] = tool
78
+ new_tools[act.tool_name] = tool
79
+ if act.tool_name not in tool_mapping:
80
+ tool_mapping[act.tool_name] = []
81
+ tool_mapping[act.tool_name].append(act)
82
+
83
+ if new_tools:
84
+ yield Message(
85
+ category=Constants.TASK,
86
+ payload=TaskItem(data=new_tools),
87
+ sender=self.name(),
88
+ session_id=message.session_id,
89
+ topic=TopicType.SUBSCRIBE_TOOL,
90
+ headers=headers
91
+ )
92
+
93
+ for tool_name, actions in tool_mapping.items():
94
+ if not (isinstance(self.tools[tool_name], Tool) or isinstance(self.tools[tool_name], AsyncTool)):
95
+ logger.warning(f"Unsupported tool type: {self.tools[tool_name]}")
96
+ continue
97
+
98
+ # send to the tool
99
+ yield Message(
100
+ category=Constants.TOOL,
101
+ payload=actions,
102
+ sender=actions[0].agent_name if actions else '',
103
+ session_id=message.session_id,
104
+ receiver=tool_name,
105
+ headers=headers
106
+ )