Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- aworld/runners/handler/agent.py +447 -0
- aworld/runners/handler/base.py +31 -0
- aworld/runners/handler/output.py +66 -0
- aworld/runners/handler/task.py +131 -0
- aworld/runners/handler/tool.py +106 -0
aworld/runners/handler/agent.py
ADDED
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
from typing import AsyncGenerator, Tuple
|
5 |
+
|
6 |
+
from aworld.agents.loop_llm_agent import LoopableAgent
|
7 |
+
from aworld.core.agent.base import is_agent, AgentFactory
|
8 |
+
from aworld.core.agent.swarm import GraphBuildType
|
9 |
+
from aworld.core.common import ActionModel, Observation, TaskItem
|
10 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
11 |
+
from aworld.logs.util import logger
|
12 |
+
from aworld.runners.handler.base import DefaultHandler
|
13 |
+
from aworld.runners.handler.tool import DefaultToolHandler
|
14 |
+
from aworld.runners.utils import endless_detect
|
15 |
+
from aworld.output.base import StepOutput
|
16 |
+
|
17 |
+
|
18 |
+
class AgentHandler(DefaultHandler):
|
19 |
+
__metaclass__ = abc.ABCMeta
|
20 |
+
|
21 |
+
def __init__(self, runner: 'TaskEventRunner'):
|
22 |
+
self.swarm = runner.swarm
|
23 |
+
self.endless_threshold = runner.endless_threshold
|
24 |
+
|
25 |
+
self.agent_calls = []
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
def name(cls):
|
29 |
+
return "_agents_handler"
|
30 |
+
|
31 |
+
|
32 |
+
class DefaultAgentHandler(AgentHandler):
|
33 |
+
async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
|
34 |
+
if message.category != Constants.AGENT:
|
35 |
+
if message.sender in self.swarm.agents and message.sender in AgentFactory:
|
36 |
+
if self.agent_calls:
|
37 |
+
if self.agent_calls[-1] != message.sender:
|
38 |
+
self.agent_calls.append(message.sender)
|
39 |
+
else:
|
40 |
+
self.agent_calls.append(message.sender)
|
41 |
+
return
|
42 |
+
|
43 |
+
headers = {"context": message.context}
|
44 |
+
session_id = message.session_id
|
45 |
+
data = message.payload
|
46 |
+
if not data:
|
47 |
+
# error message, p2p
|
48 |
+
yield Message(
|
49 |
+
category=Constants.OUTPUT,
|
50 |
+
payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
|
51 |
+
step_num=0,
|
52 |
+
data="no data to process."),
|
53 |
+
sender=self.name(),
|
54 |
+
session_id=session_id,
|
55 |
+
headers=headers
|
56 |
+
)
|
57 |
+
yield Message(
|
58 |
+
category=Constants.TASK,
|
59 |
+
payload=TaskItem(msg="no data to process.", data=data, stop=True),
|
60 |
+
sender=self.name(),
|
61 |
+
session_id=session_id,
|
62 |
+
topic=TopicType.ERROR,
|
63 |
+
headers=headers
|
64 |
+
)
|
65 |
+
return
|
66 |
+
|
67 |
+
if isinstance(data, Tuple) and isinstance(data[0], Observation):
|
68 |
+
data = data[0]
|
69 |
+
message.payload = data
|
70 |
+
# data is Observation
|
71 |
+
if isinstance(data, Observation):
|
72 |
+
if not self.swarm:
|
73 |
+
msg = Message(
|
74 |
+
category=Constants.TASK,
|
75 |
+
payload=data.content,
|
76 |
+
sender=data.observer,
|
77 |
+
session_id=session_id,
|
78 |
+
topic=TopicType.FINISHED,
|
79 |
+
headers=headers
|
80 |
+
)
|
81 |
+
logger.info(f"agent handler send finished message: {msg}")
|
82 |
+
yield msg
|
83 |
+
return
|
84 |
+
|
85 |
+
agent = self.swarm.agents.get(message.receiver)
|
86 |
+
# agent + tool completion protocol.
|
87 |
+
if agent and agent.finished and data.info.get('done'):
|
88 |
+
self.swarm.cur_step += 1
|
89 |
+
if agent.id() == self.swarm.communicate_agent.id():
|
90 |
+
msg = Message(
|
91 |
+
category=Constants.TASK,
|
92 |
+
payload=data.content,
|
93 |
+
sender=agent.id(),
|
94 |
+
session_id=session_id,
|
95 |
+
topic=TopicType.FINISHED,
|
96 |
+
headers=headers
|
97 |
+
)
|
98 |
+
logger.info(f"agent handler send finished message: {msg}")
|
99 |
+
yield msg
|
100 |
+
else:
|
101 |
+
msg = Message(
|
102 |
+
category=Constants.AGENT,
|
103 |
+
payload=Observation(content=data.content),
|
104 |
+
sender=agent.id(),
|
105 |
+
session_id=session_id,
|
106 |
+
receiver=self.swarm.communicate_agent.id(),
|
107 |
+
headers=headers
|
108 |
+
)
|
109 |
+
logger.info(f"agent handler send agent message: {msg}")
|
110 |
+
yield msg
|
111 |
+
else:
|
112 |
+
if data.info.get('done'):
|
113 |
+
agent_name = self.agent_calls[-1]
|
114 |
+
async for event in self._stop_check(ActionModel(agent_name=agent_name, policy_info=data.content),
|
115 |
+
message):
|
116 |
+
yield event
|
117 |
+
return
|
118 |
+
logger.info(f"agent handler send observation message: {message}")
|
119 |
+
yield message
|
120 |
+
return
|
121 |
+
|
122 |
+
# data is List[ActionModel]
|
123 |
+
for action in data:
|
124 |
+
if not isinstance(action, ActionModel):
|
125 |
+
# error message, p2p
|
126 |
+
yield Message(
|
127 |
+
category=Constants.OUTPUT,
|
128 |
+
payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}",
|
129 |
+
step_num=0,
|
130 |
+
data="action not a ActionModel."),
|
131 |
+
sender=self.name(),
|
132 |
+
session_id=session_id,
|
133 |
+
headers=headers
|
134 |
+
)
|
135 |
+
msg = Message(
|
136 |
+
category=Constants.TASK,
|
137 |
+
payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
|
138 |
+
sender=self.name(),
|
139 |
+
session_id=session_id,
|
140 |
+
topic=TopicType.ERROR,
|
141 |
+
headers=headers
|
142 |
+
)
|
143 |
+
logger.info(f"agent handler send task message: {msg}")
|
144 |
+
yield msg
|
145 |
+
return
|
146 |
+
|
147 |
+
tools = []
|
148 |
+
agents = []
|
149 |
+
for action in data:
|
150 |
+
if is_agent(action):
|
151 |
+
agents.append(action)
|
152 |
+
else:
|
153 |
+
tools.append(action)
|
154 |
+
|
155 |
+
if tools:
|
156 |
+
msg = Message(
|
157 |
+
category=Constants.TOOL,
|
158 |
+
payload=tools,
|
159 |
+
sender=self.name(),
|
160 |
+
session_id=session_id,
|
161 |
+
receiver=DefaultToolHandler.name(),
|
162 |
+
headers=headers
|
163 |
+
)
|
164 |
+
logger.info(f"agent handler send tool message: {msg}")
|
165 |
+
yield msg
|
166 |
+
else:
|
167 |
+
yield Message(
|
168 |
+
category=Constants.OUTPUT,
|
169 |
+
payload=StepOutput.build_finished_output(name=f"{message.caller or self.name()}",
|
170 |
+
step_num=0),
|
171 |
+
sender=self.name(),
|
172 |
+
receiver=agents[0].tool_name,
|
173 |
+
session_id=session_id,
|
174 |
+
headers=headers
|
175 |
+
)
|
176 |
+
|
177 |
+
for agent in agents:
|
178 |
+
async for event in self._agent(agent, message):
|
179 |
+
logger.info(f"agent handler send message: {event}")
|
180 |
+
yield event
|
181 |
+
|
182 |
+
async def _agent(self, action: ActionModel, message: Message):
|
183 |
+
self.agent_calls.append(action.agent_name)
|
184 |
+
agent = self.swarm.agents.get(action.agent_name)
|
185 |
+
# be handoff
|
186 |
+
agent_name = action.tool_name
|
187 |
+
if not agent_name:
|
188 |
+
async for event in self._stop_check(action, message):
|
189 |
+
yield event
|
190 |
+
return
|
191 |
+
|
192 |
+
headers = {"context": message.context}
|
193 |
+
session_id = message.session_id
|
194 |
+
cur_agent = self.swarm.agents.get(agent_name)
|
195 |
+
if not cur_agent or not agent:
|
196 |
+
yield Message(
|
197 |
+
category=Constants.TASK,
|
198 |
+
payload=TaskItem(msg=f"Can not find {agent_name} or {action.agent_name} agent in swarm.",
|
199 |
+
data=action,
|
200 |
+
stop=True),
|
201 |
+
sender=self.name(),
|
202 |
+
session_id=session_id,
|
203 |
+
topic=TopicType.ERROR,
|
204 |
+
headers=headers
|
205 |
+
)
|
206 |
+
return
|
207 |
+
|
208 |
+
cur_agent._finished = False
|
209 |
+
con = action.policy_info
|
210 |
+
if action.params and 'content' in action.params:
|
211 |
+
con = action.params['content']
|
212 |
+
observation = Observation(content=con, observer=agent.id(), from_agent_name=agent.id())
|
213 |
+
|
214 |
+
if agent.handoffs and agent_name not in agent.handoffs:
|
215 |
+
if message.caller:
|
216 |
+
message.receiver = message.caller
|
217 |
+
message.caller = ''
|
218 |
+
yield message
|
219 |
+
else:
|
220 |
+
yield Message(category=Constants.TASK,
|
221 |
+
payload=TaskItem(msg=f"Can not handoffs {agent_name} agent ", data=observation),
|
222 |
+
sender=self.name(),
|
223 |
+
session_id=session_id,
|
224 |
+
topic=TopicType.RERUN,
|
225 |
+
headers=headers)
|
226 |
+
return
|
227 |
+
|
228 |
+
yield Message(
|
229 |
+
category=Constants.AGENT,
|
230 |
+
payload=observation,
|
231 |
+
caller=message.caller,
|
232 |
+
sender=action.agent_name,
|
233 |
+
session_id=session_id,
|
234 |
+
receiver=action.tool_name,
|
235 |
+
headers=headers
|
236 |
+
)
|
237 |
+
|
238 |
+
async def _stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
|
239 |
+
if GraphBuildType.WORKFLOW.value != self.swarm.build_type:
|
240 |
+
async for event in self._social_stop_check(action, message):
|
241 |
+
yield event
|
242 |
+
else:
|
243 |
+
if self.swarm.has_cycle:
|
244 |
+
async for event in self._loop_sequence_stop_check(action, message):
|
245 |
+
yield event
|
246 |
+
else:
|
247 |
+
async for event in self._sequence_stop_check(action, message):
|
248 |
+
yield event
|
249 |
+
|
250 |
+
async def _sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
|
251 |
+
headers = {"context": message.context}
|
252 |
+
session_id = message.session_id
|
253 |
+
agent = self.swarm.agents.get(action.agent_name)
|
254 |
+
ordered_agents = self.swarm.ordered_agents
|
255 |
+
idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1)
|
256 |
+
if idx == -1:
|
257 |
+
yield Message(
|
258 |
+
category=Constants.TASK,
|
259 |
+
payload=action,
|
260 |
+
sender=self.name(),
|
261 |
+
session_id=session_id,
|
262 |
+
topic=TopicType.ERROR,
|
263 |
+
headers=headers
|
264 |
+
)
|
265 |
+
return
|
266 |
+
|
267 |
+
# The last agent
|
268 |
+
if idx == len(self.swarm.ordered_agents) - 1:
|
269 |
+
receiver = None
|
270 |
+
# agent loop
|
271 |
+
if isinstance(agent, LoopableAgent):
|
272 |
+
agent.cur_run_times += 1
|
273 |
+
if not agent.finished:
|
274 |
+
receiver = agent.goto
|
275 |
+
|
276 |
+
if receiver:
|
277 |
+
yield Message(
|
278 |
+
category=Constants.AGENT,
|
279 |
+
payload=Observation(content=action.policy_info),
|
280 |
+
sender=agent.id(),
|
281 |
+
session_id=session_id,
|
282 |
+
receiver=receiver,
|
283 |
+
headers=headers
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
logger.info(f"execute loop {self.swarm.cur_step}.")
|
287 |
+
yield Message(
|
288 |
+
category=Constants.TASK,
|
289 |
+
payload=action.policy_info,
|
290 |
+
sender=agent.id(),
|
291 |
+
session_id=session_id,
|
292 |
+
topic=TopicType.FINISHED,
|
293 |
+
headers=headers
|
294 |
+
)
|
295 |
+
return
|
296 |
+
|
297 |
+
# loop agent type
|
298 |
+
if isinstance(agent, LoopableAgent):
|
299 |
+
agent.cur_run_times += 1
|
300 |
+
if agent.finished:
|
301 |
+
receiver = self.swarm.ordered_agents[idx + 1].id()
|
302 |
+
else:
|
303 |
+
receiver = agent.goto
|
304 |
+
else:
|
305 |
+
# means the loop finished
|
306 |
+
receiver = self.swarm.ordered_agents[idx + 1].id()
|
307 |
+
yield Message(
|
308 |
+
category=Constants.AGENT,
|
309 |
+
payload=Observation(content=action.policy_info),
|
310 |
+
sender=agent.id(),
|
311 |
+
session_id=session_id,
|
312 |
+
receiver=receiver,
|
313 |
+
headers=headers
|
314 |
+
)
|
315 |
+
|
316 |
+
async def _loop_sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
|
317 |
+
headers = {"context": message.context}
|
318 |
+
session_id = message.session_id
|
319 |
+
agent = self.swarm.agents.get(action.agent_name)
|
320 |
+
idx = next((i for i, x in enumerate(self.swarm.ordered_agents) if x == agent), -1)
|
321 |
+
if idx == -1:
|
322 |
+
# unknown agent, means something wrong
|
323 |
+
yield Message(
|
324 |
+
category=Constants.TASK,
|
325 |
+
payload=action,
|
326 |
+
sender=self.name(),
|
327 |
+
session_id=session_id,
|
328 |
+
topic=TopicType.ERROR,
|
329 |
+
headers=headers
|
330 |
+
)
|
331 |
+
return
|
332 |
+
if idx == len(self.swarm.ordered_agents) - 1:
|
333 |
+
# supported sequence loop
|
334 |
+
if self.swarm.cur_step >= self.swarm.max_steps:
|
335 |
+
receiver = None
|
336 |
+
# agent loop
|
337 |
+
if isinstance(agent, LoopableAgent):
|
338 |
+
agent.cur_run_times += 1
|
339 |
+
if not agent.finished:
|
340 |
+
receiver = agent.goto
|
341 |
+
|
342 |
+
if receiver:
|
343 |
+
yield Message(
|
344 |
+
category=Constants.AGENT,
|
345 |
+
payload=Observation(content=action.policy_info),
|
346 |
+
sender=agent.id(),
|
347 |
+
session_id=session_id,
|
348 |
+
receiver=receiver,
|
349 |
+
headers=headers
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
# means the task finished
|
353 |
+
yield Message(
|
354 |
+
category=Constants.TASK,
|
355 |
+
payload=action.policy_info,
|
356 |
+
sender=agent.id(),
|
357 |
+
session_id=session_id,
|
358 |
+
topic=TopicType.FINISHED,
|
359 |
+
headers=headers
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
self.swarm.cur_step += 1
|
363 |
+
logger.info(f"execute loop {self.swarm.cur_step}.")
|
364 |
+
yield Message(
|
365 |
+
category=Constants.TASK,
|
366 |
+
payload='',
|
367 |
+
sender=agent.id(),
|
368 |
+
session_id=session_id,
|
369 |
+
topic=TopicType.START,
|
370 |
+
headers=headers
|
371 |
+
)
|
372 |
+
return
|
373 |
+
|
374 |
+
if isinstance(agent, LoopableAgent):
|
375 |
+
agent.cur_run_times += 1
|
376 |
+
if agent.finished:
|
377 |
+
receiver = self.swarm.ordered_agents[idx + 1].id()
|
378 |
+
else:
|
379 |
+
receiver = agent.goto
|
380 |
+
else:
|
381 |
+
# means the loop finished
|
382 |
+
receiver = self.swarm.ordered_agents[idx + 1].id()
|
383 |
+
yield Message(
|
384 |
+
category=Constants.AGENT,
|
385 |
+
payload=Observation(content=action.policy_info),
|
386 |
+
sender=agent.name(),
|
387 |
+
session_id=session_id,
|
388 |
+
receiver=receiver,
|
389 |
+
headers=headers
|
390 |
+
)
|
391 |
+
|
392 |
+
async def _social_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]:
|
393 |
+
headers = {"context": message.context}
|
394 |
+
agent = self.swarm.agents.get(action.agent_name)
|
395 |
+
caller = message.caller
|
396 |
+
session_id = message.session_id
|
397 |
+
if endless_detect(self.agent_calls,
|
398 |
+
endless_threshold=self.endless_threshold,
|
399 |
+
root_agent_name=self.swarm.communicate_agent.id()):
|
400 |
+
yield Message(
|
401 |
+
category=Constants.TASK,
|
402 |
+
payload=action.policy_info,
|
403 |
+
sender=agent.id(),
|
404 |
+
session_id=session_id,
|
405 |
+
topic=TopicType.FINISHED,
|
406 |
+
headers=headers
|
407 |
+
)
|
408 |
+
return
|
409 |
+
|
410 |
+
if not caller or caller == self.swarm.communicate_agent.id():
|
411 |
+
if self.swarm.cur_step >= self.swarm.max_steps or self.swarm.finished:
|
412 |
+
yield Message(
|
413 |
+
category=Constants.TASK,
|
414 |
+
payload=action.policy_info,
|
415 |
+
sender=agent.id(),
|
416 |
+
session_id=session_id,
|
417 |
+
topic=TopicType.FINISHED,
|
418 |
+
headers=headers
|
419 |
+
)
|
420 |
+
else:
|
421 |
+
self.swarm.cur_step += 1
|
422 |
+
logger.info(f"execute loop {self.swarm.cur_step}.")
|
423 |
+
yield Message(
|
424 |
+
category=Constants.AGENT,
|
425 |
+
payload=Observation(content=action.policy_info),
|
426 |
+
sender=agent.id(),
|
427 |
+
session_id=session_id,
|
428 |
+
receiver=self.swarm.communicate_agent.id(),
|
429 |
+
headers=headers
|
430 |
+
)
|
431 |
+
else:
|
432 |
+
idx = 0
|
433 |
+
for idx, name in enumerate(self.agent_calls[::-1]):
|
434 |
+
if name == agent.id():
|
435 |
+
break
|
436 |
+
idx = len(self.agent_calls) - idx - 1
|
437 |
+
if idx:
|
438 |
+
caller = self.agent_calls[idx - 1]
|
439 |
+
|
440 |
+
yield Message(
|
441 |
+
category=Constants.AGENT,
|
442 |
+
payload=Observation(content=action.policy_info),
|
443 |
+
sender=agent.id(),
|
444 |
+
session_id=session_id,
|
445 |
+
receiver=caller,
|
446 |
+
headers=headers
|
447 |
+
)
|
aworld/runners/handler/base.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
|
5 |
+
from typing import TypeVar, Generic, AsyncGenerator
|
6 |
+
|
7 |
+
from aworld.core.event.base import Message
|
8 |
+
|
9 |
+
IN = TypeVar('IN')
|
10 |
+
OUT = TypeVar('OUT')
|
11 |
+
|
12 |
+
|
13 |
+
class Handler(Generic[IN, OUT]):
|
14 |
+
__metaclass__ = abc.ABCMeta
|
15 |
+
|
16 |
+
@abc.abstractmethod
|
17 |
+
async def handle(self, data: IN) -> AsyncGenerator[OUT, None]:
|
18 |
+
"""Process the data as the expected result.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
data: Data generated while running the task.
|
22 |
+
"""
|
23 |
+
|
24 |
+
@classmethod
|
25 |
+
def name(cls):
|
26 |
+
"""Handler name."""
|
27 |
+
return cls.__name__
|
28 |
+
|
29 |
+
|
30 |
+
class DefaultHandler(Handler[Message, AsyncGenerator[Message, None]]):
|
31 |
+
"""Default handler."""
|
aworld/runners/handler/output.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# aworld/runners/handler/output.py
|
2 |
+
import json
|
3 |
+
from typing import AsyncGenerator
|
4 |
+
from aworld.core.task import TaskResponse
|
5 |
+
from aworld.models.model_response import ModelResponse
|
6 |
+
from aworld.runners.handler.base import DefaultHandler
|
7 |
+
from aworld.output.base import StepOutput, MessageOutput, ToolResultOutput, Output
|
8 |
+
from aworld.core.common import TaskItem
|
9 |
+
from aworld.core.context.base import Context
|
10 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
11 |
+
from aworld.logs.util import logger
|
12 |
+
|
13 |
+
|
14 |
+
class DefaultOutputHandler(DefaultHandler):
|
15 |
+
def __init__(self, runner):
|
16 |
+
self.runner = runner
|
17 |
+
|
18 |
+
async def handle(self, message):
|
19 |
+
if message.category != Constants.OUTPUT:
|
20 |
+
return
|
21 |
+
# 1. get outputs
|
22 |
+
outputs = self.runner.task.outputs
|
23 |
+
if not outputs:
|
24 |
+
yield Message(
|
25 |
+
category=Constants.TASK,
|
26 |
+
payload=TaskItem(msg="Cannot get outputs.", data=message, stop=True),
|
27 |
+
sender=self.name(),
|
28 |
+
session_id=Context.instance().session_id,
|
29 |
+
topic=TopicType.ERROR,
|
30 |
+
headers={"context": message.context}
|
31 |
+
)
|
32 |
+
return
|
33 |
+
# 2. build Output
|
34 |
+
payload = message.payload
|
35 |
+
mark_complete = False
|
36 |
+
output = None
|
37 |
+
try:
|
38 |
+
if isinstance(payload, Output):
|
39 |
+
output = payload
|
40 |
+
elif isinstance(payload, TaskResponse):
|
41 |
+
logger.info(f"output get task_response with usage: {json.dumps(payload.usage)}")
|
42 |
+
if message.topic == TopicType.FINISHED or message.topic == TopicType.ERROR:
|
43 |
+
mark_complete = True
|
44 |
+
elif isinstance(payload, ModelResponse) or isinstance(payload, AsyncGenerator):
|
45 |
+
output = MessageOutput(source=payload)
|
46 |
+
except Exception as e:
|
47 |
+
logger.warning(f"Failed to parse output: {e}")
|
48 |
+
yield Message(
|
49 |
+
category=Constants.TASK,
|
50 |
+
payload=TaskItem(msg="Failed to parse output.", data=payload, stop=True),
|
51 |
+
sender=self.name(),
|
52 |
+
session_id=Context.instance().session_id,
|
53 |
+
topic=TopicType.ERROR,
|
54 |
+
headers={"context": message.context}
|
55 |
+
)
|
56 |
+
finally:
|
57 |
+
if output:
|
58 |
+
if not output.metadata:
|
59 |
+
output.metadata = {}
|
60 |
+
output.metadata['sender'] = message.sender
|
61 |
+
output.metadata['receiver'] = message.receiver
|
62 |
+
await outputs.add_output(output)
|
63 |
+
if mark_complete:
|
64 |
+
await outputs.mark_completed()
|
65 |
+
|
66 |
+
return
|
aworld/runners/handler/task.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
import time
|
5 |
+
|
6 |
+
from typing import AsyncGenerator
|
7 |
+
|
8 |
+
from aworld.core.common import TaskItem
|
9 |
+
from aworld.core.tool.base import Tool, AsyncTool
|
10 |
+
|
11 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
12 |
+
from aworld.core.task import TaskResponse
|
13 |
+
from aworld.logs.util import logger
|
14 |
+
from aworld.output import Output
|
15 |
+
from aworld.runners.handler.base import DefaultHandler
|
16 |
+
from aworld.runners.hook.hook_factory import HookFactory
|
17 |
+
from aworld.runners.hook.hooks import HookPoint
|
18 |
+
|
19 |
+
|
20 |
+
class TaskHandler(DefaultHandler):
|
21 |
+
__metaclass__ = abc.ABCMeta
|
22 |
+
|
23 |
+
def __init__(self, runner: 'TaskEventRunner'):
|
24 |
+
self.runner = runner
|
25 |
+
self.retry_count = 0
|
26 |
+
self.hooks = {}
|
27 |
+
if runner.task.hooks:
|
28 |
+
for k, vals in runner.task.hooks.items():
|
29 |
+
self.hooks[k] = []
|
30 |
+
for v in vals:
|
31 |
+
cls = HookFactory.get_class(v)
|
32 |
+
if cls:
|
33 |
+
self.hooks[k].append(cls)
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def name(cls):
|
37 |
+
return "_task_handler"
|
38 |
+
|
39 |
+
|
40 |
+
class DefaultTaskHandler(TaskHandler):
|
41 |
+
async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
|
42 |
+
if message.category != Constants.TASK:
|
43 |
+
return
|
44 |
+
|
45 |
+
logger.info(f"task handler receive message: {message}")
|
46 |
+
|
47 |
+
headers = {"context": message.context}
|
48 |
+
topic = message.topic
|
49 |
+
task_item: TaskItem = message.payload
|
50 |
+
if topic == TopicType.SUBSCRIBE_TOOL:
|
51 |
+
new_tools = message.payload.data
|
52 |
+
for name, tool in new_tools.items():
|
53 |
+
if isinstance(tool, Tool) or isinstance(tool, AsyncTool):
|
54 |
+
await self.runner.event_mng.register(Constants.TOOL, name, tool.step)
|
55 |
+
logger.info(f"dynamic register {name} tool.")
|
56 |
+
else:
|
57 |
+
logger.warning(f"Unknown tool instance: {tool}")
|
58 |
+
return
|
59 |
+
elif topic == TopicType.SUBSCRIBE_AGENT:
|
60 |
+
return
|
61 |
+
elif topic == TopicType.ERROR:
|
62 |
+
async for event in self.run_hooks(message, HookPoint.ERROR):
|
63 |
+
yield event
|
64 |
+
|
65 |
+
if task_item.stop:
|
66 |
+
await self.runner.stop()
|
67 |
+
logger.warning(f"task {self.runner.task.id} stop, cause: {task_item.msg}")
|
68 |
+
self.runner._task_response = TaskResponse(msg=task_item.msg,
|
69 |
+
answer='',
|
70 |
+
success=False,
|
71 |
+
id=self.runner.task.id,
|
72 |
+
time_cost=(time.time() - self.runner.start_time),
|
73 |
+
usage=self.runner.context.token_usage)
|
74 |
+
return
|
75 |
+
# restart
|
76 |
+
logger.warning(f"The task {self.runner.task.id} will be restarted due to error: {task_item.msg}.")
|
77 |
+
if self.retry_count >= 3:
|
78 |
+
raise Exception(f"The task {self.runner.task.id} failed, due to error: {task_item.msg}.")
|
79 |
+
|
80 |
+
self.retry_count += 1
|
81 |
+
yield Message(
|
82 |
+
category=Constants.TASK,
|
83 |
+
payload='',
|
84 |
+
sender=self.name(),
|
85 |
+
session_id=self.runner.context.session_id,
|
86 |
+
topic=TopicType.START,
|
87 |
+
headers=headers
|
88 |
+
)
|
89 |
+
elif topic == TopicType.FINISHED:
|
90 |
+
async for event in self.run_hooks(message, HookPoint.FINISHED):
|
91 |
+
yield event
|
92 |
+
|
93 |
+
self.runner._task_response = TaskResponse(answer=str(message.payload),
|
94 |
+
success=True,
|
95 |
+
id=self.runner.task.id,
|
96 |
+
time_cost=(time.time() - self.runner.start_time),
|
97 |
+
usage=self.runner.context.token_usage)
|
98 |
+
await self.runner.stop()
|
99 |
+
|
100 |
+
logger.info(f"{self.runner.task.id} finished.")
|
101 |
+
elif topic == TopicType.START:
|
102 |
+
async for event in self.run_hooks(message, HookPoint.START):
|
103 |
+
yield event
|
104 |
+
|
105 |
+
logger.info(f"task start event: {message}, will send init message.")
|
106 |
+
if message.payload:
|
107 |
+
yield message
|
108 |
+
else:
|
109 |
+
yield self.runner.init_message
|
110 |
+
elif topic == TopicType.OUTPUT:
|
111 |
+
yield message
|
112 |
+
elif topic == TopicType.HUMAN_CONFIRM:
|
113 |
+
logger.warn("=============== Get human confirm, pause execution ===============")
|
114 |
+
if self.runner.task.outputs and message.payload:
|
115 |
+
await self.runner.task.outputs.add_output(Output(data=message.payload))
|
116 |
+
self.runner._task_response = TaskResponse(answer=str(message.payload),
|
117 |
+
success=True,
|
118 |
+
id=self.runner.task.id,
|
119 |
+
time_cost=(time.time() - self.runner.start_time),
|
120 |
+
usage=self.runner.context.token_usage)
|
121 |
+
await self.runner.stop()
|
122 |
+
|
123 |
+
async def run_hooks(self, message: Message, hook_point: str) -> AsyncGenerator[Message, None]:
|
124 |
+
hooks = self.hooks.get(hook_point, [])
|
125 |
+
for hook in hooks:
|
126 |
+
try:
|
127 |
+
msg = hook(message)
|
128 |
+
if msg:
|
129 |
+
yield msg
|
130 |
+
except:
|
131 |
+
logger.warning(f"{hook.point()} {hook.name()} execute fail.")
|
aworld/runners/handler/tool.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
# Copyright (c) 2025 inclusionAI.
|
3 |
+
import abc
|
4 |
+
from typing import AsyncGenerator
|
5 |
+
|
6 |
+
from aworld.core.agent.base import is_agent
|
7 |
+
from aworld.core.common import ActionModel, TaskItem
|
8 |
+
from aworld.core.event.base import Message, Constants, TopicType
|
9 |
+
from aworld.core.tool.base import AsyncTool, Tool, ToolFactory
|
10 |
+
from aworld.logs.util import logger
|
11 |
+
from aworld.runners.handler.base import DefaultHandler
|
12 |
+
|
13 |
+
|
14 |
+
class ToolHandler(DefaultHandler):
|
15 |
+
__metaclass__ = abc.ABCMeta
|
16 |
+
|
17 |
+
def __init__(self, runner: 'TaskEventRunner'):
|
18 |
+
self.tools = runner.tools
|
19 |
+
self.tools_conf = runner.tools_conf
|
20 |
+
|
21 |
+
@classmethod
|
22 |
+
def name(cls):
|
23 |
+
return "_tool_handler"
|
24 |
+
|
25 |
+
|
26 |
+
class DefaultToolHandler(ToolHandler):
|
27 |
+
async def handle(self, message: Message) -> AsyncGenerator[Message, None]:
|
28 |
+
if message.category != Constants.TOOL:
|
29 |
+
return
|
30 |
+
|
31 |
+
headers = {"context": message.context}
|
32 |
+
# data is List[ActionModel]
|
33 |
+
data = message.payload
|
34 |
+
if not data:
|
35 |
+
# error message, p2p
|
36 |
+
yield Message(
|
37 |
+
category=Constants.TASK,
|
38 |
+
payload=TaskItem(msg="no data to process.", data=data, stop=True),
|
39 |
+
sender='agent_handler',
|
40 |
+
session_id=message.session_id,
|
41 |
+
topic=TopicType.ERROR,
|
42 |
+
headers=headers
|
43 |
+
)
|
44 |
+
return
|
45 |
+
|
46 |
+
for action in data:
|
47 |
+
if not isinstance(action, ActionModel):
|
48 |
+
# error message, p2p
|
49 |
+
yield Message(
|
50 |
+
category=Constants.TASK,
|
51 |
+
payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True),
|
52 |
+
sender=self.name(),
|
53 |
+
session_id=message.session_id,
|
54 |
+
topic=TopicType.ERROR,
|
55 |
+
headers=headers
|
56 |
+
)
|
57 |
+
return
|
58 |
+
|
59 |
+
new_tools = dict()
|
60 |
+
tool_mapping = dict()
|
61 |
+
# Directly use or use tools after creation.
|
62 |
+
for act in data:
|
63 |
+
if is_agent(act):
|
64 |
+
logger.warning(f"somethings wrong, {act} is an agent.")
|
65 |
+
continue
|
66 |
+
|
67 |
+
if not self.tools or (self.tools and act.tool_name not in self.tools):
|
68 |
+
# dynamic only use default config in module.
|
69 |
+
conf = self.tools_conf.get(act.tool_name)
|
70 |
+
tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False)
|
71 |
+
tool.event_driven = True
|
72 |
+
if isinstance(tool, Tool):
|
73 |
+
tool.reset()
|
74 |
+
elif isinstance(tool, AsyncTool):
|
75 |
+
await tool.reset()
|
76 |
+
tool_mapping[act.tool_name] = []
|
77 |
+
self.tools[act.tool_name] = tool
|
78 |
+
new_tools[act.tool_name] = tool
|
79 |
+
if act.tool_name not in tool_mapping:
|
80 |
+
tool_mapping[act.tool_name] = []
|
81 |
+
tool_mapping[act.tool_name].append(act)
|
82 |
+
|
83 |
+
if new_tools:
|
84 |
+
yield Message(
|
85 |
+
category=Constants.TASK,
|
86 |
+
payload=TaskItem(data=new_tools),
|
87 |
+
sender=self.name(),
|
88 |
+
session_id=message.session_id,
|
89 |
+
topic=TopicType.SUBSCRIBE_TOOL,
|
90 |
+
headers=headers
|
91 |
+
)
|
92 |
+
|
93 |
+
for tool_name, actions in tool_mapping.items():
|
94 |
+
if not (isinstance(self.tools[tool_name], Tool) or isinstance(self.tools[tool_name], AsyncTool)):
|
95 |
+
logger.warning(f"Unsupported tool type: {self.tools[tool_name]}")
|
96 |
+
continue
|
97 |
+
|
98 |
+
# send to the tool
|
99 |
+
yield Message(
|
100 |
+
category=Constants.TOOL,
|
101 |
+
payload=actions,
|
102 |
+
sender=actions[0].agent_name if actions else '',
|
103 |
+
session_id=message.session_id,
|
104 |
+
receiver=tool_name,
|
105 |
+
headers=headers
|
106 |
+
)
|