Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| # Copyright (c) 2025 inclusionAI. | |
| import abc | |
| from typing import AsyncGenerator, Tuple | |
| from aworld.agents.loop_llm_agent import LoopableAgent | |
| from aworld.core.agent.base import is_agent, AgentFactory | |
| from aworld.core.agent.swarm import GraphBuildType | |
| from aworld.core.common import ActionModel, Observation, TaskItem | |
| from aworld.core.event.base import Message, Constants, TopicType | |
| from aworld.logs.util import logger | |
| from aworld.runners.handler.base import DefaultHandler | |
| from aworld.runners.handler.tool import DefaultToolHandler | |
| from aworld.runners.utils import endless_detect | |
| from aworld.output.base import StepOutput | |
| class AgentHandler(DefaultHandler): | |
| __metaclass__ = abc.ABCMeta | |
| def __init__(self, runner: 'TaskEventRunner'): | |
| self.swarm = runner.swarm | |
| self.endless_threshold = runner.endless_threshold | |
| self.agent_calls = [] | |
| def name(cls): | |
| return "_agents_handler" | |
| class DefaultAgentHandler(AgentHandler): | |
| async def handle(self, message: Message) -> AsyncGenerator[Message, None]: | |
| if message.category != Constants.AGENT: | |
| if message.sender in self.swarm.agents and message.sender in AgentFactory: | |
| if self.agent_calls: | |
| if self.agent_calls[-1] != message.sender: | |
| self.agent_calls.append(message.sender) | |
| else: | |
| self.agent_calls.append(message.sender) | |
| return | |
| headers = {"context": message.context} | |
| session_id = message.session_id | |
| data = message.payload | |
| if not data: | |
| # error message, p2p | |
| yield Message( | |
| category=Constants.OUTPUT, | |
| payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}", | |
| step_num=0, | |
| data="no data to process."), | |
| sender=self.name(), | |
| session_id=session_id, | |
| headers=headers | |
| ) | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=TaskItem(msg="no data to process.", data=data, stop=True), | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.ERROR, | |
| headers=headers | |
| ) | |
| return | |
| if isinstance(data, Tuple) and isinstance(data[0], Observation): | |
| data = data[0] | |
| message.payload = data | |
| # data is Observation | |
| if isinstance(data, Observation): | |
| if not self.swarm: | |
| msg = Message( | |
| category=Constants.TASK, | |
| payload=data.content, | |
| sender=data.observer, | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| logger.info(f"agent handler send finished message: {msg}") | |
| yield msg | |
| return | |
| agent = self.swarm.agents.get(message.receiver) | |
| # agent + tool completion protocol. | |
| if agent and agent.finished and data.info.get('done'): | |
| self.swarm.cur_step += 1 | |
| if agent.id() == self.swarm.communicate_agent.id(): | |
| msg = Message( | |
| category=Constants.TASK, | |
| payload=data.content, | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| logger.info(f"agent handler send finished message: {msg}") | |
| yield msg | |
| else: | |
| msg = Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=data.content), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=self.swarm.communicate_agent.id(), | |
| headers=headers | |
| ) | |
| logger.info(f"agent handler send agent message: {msg}") | |
| yield msg | |
| else: | |
| if data.info.get('done'): | |
| agent_name = self.agent_calls[-1] | |
| async for event in self._stop_check(ActionModel(agent_name=agent_name, policy_info=data.content), | |
| message): | |
| yield event | |
| return | |
| logger.info(f"agent handler send observation message: {message}") | |
| yield message | |
| return | |
| # data is List[ActionModel] | |
| for action in data: | |
| if not isinstance(action, ActionModel): | |
| # error message, p2p | |
| yield Message( | |
| category=Constants.OUTPUT, | |
| payload=StepOutput.build_failed_output(name=f"{message.caller or self.name()}", | |
| step_num=0, | |
| data="action not a ActionModel."), | |
| sender=self.name(), | |
| session_id=session_id, | |
| headers=headers | |
| ) | |
| msg = Message( | |
| category=Constants.TASK, | |
| payload=TaskItem(msg="action not a ActionModel.", data=data, stop=True), | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.ERROR, | |
| headers=headers | |
| ) | |
| logger.info(f"agent handler send task message: {msg}") | |
| yield msg | |
| return | |
| tools = [] | |
| agents = [] | |
| for action in data: | |
| if is_agent(action): | |
| agents.append(action) | |
| else: | |
| tools.append(action) | |
| if tools: | |
| msg = Message( | |
| category=Constants.TOOL, | |
| payload=tools, | |
| sender=self.name(), | |
| session_id=session_id, | |
| receiver=DefaultToolHandler.name(), | |
| headers=headers | |
| ) | |
| logger.info(f"agent handler send tool message: {msg}") | |
| yield msg | |
| else: | |
| yield Message( | |
| category=Constants.OUTPUT, | |
| payload=StepOutput.build_finished_output(name=f"{message.caller or self.name()}", | |
| step_num=0), | |
| sender=self.name(), | |
| receiver=agents[0].tool_name, | |
| session_id=session_id, | |
| headers=headers | |
| ) | |
| for agent in agents: | |
| async for event in self._agent(agent, message): | |
| logger.info(f"agent handler send message: {event}") | |
| yield event | |
| async def _agent(self, action: ActionModel, message: Message): | |
| self.agent_calls.append(action.agent_name) | |
| agent = self.swarm.agents.get(action.agent_name) | |
| # be handoff | |
| agent_name = action.tool_name | |
| if not agent_name: | |
| async for event in self._stop_check(action, message): | |
| yield event | |
| return | |
| headers = {"context": message.context} | |
| session_id = message.session_id | |
| cur_agent = self.swarm.agents.get(agent_name) | |
| if not cur_agent or not agent: | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=TaskItem(msg=f"Can not find {agent_name} or {action.agent_name} agent in swarm.", | |
| data=action, | |
| stop=True), | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.ERROR, | |
| headers=headers | |
| ) | |
| return | |
| cur_agent._finished = False | |
| con = action.policy_info | |
| if action.params and 'content' in action.params: | |
| con = action.params['content'] | |
| observation = Observation(content=con, observer=agent.id(), from_agent_name=agent.id()) | |
| if agent.handoffs and agent_name not in agent.handoffs: | |
| if message.caller: | |
| message.receiver = message.caller | |
| message.caller = '' | |
| yield message | |
| else: | |
| yield Message(category=Constants.TASK, | |
| payload=TaskItem(msg=f"Can not handoffs {agent_name} agent ", data=observation), | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.RERUN, | |
| headers=headers) | |
| return | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=observation, | |
| caller=message.caller, | |
| sender=action.agent_name, | |
| session_id=session_id, | |
| receiver=action.tool_name, | |
| headers=headers | |
| ) | |
| async def _stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]: | |
| if GraphBuildType.WORKFLOW.value != self.swarm.build_type: | |
| async for event in self._social_stop_check(action, message): | |
| yield event | |
| else: | |
| if self.swarm.has_cycle: | |
| async for event in self._loop_sequence_stop_check(action, message): | |
| yield event | |
| else: | |
| async for event in self._sequence_stop_check(action, message): | |
| yield event | |
| async def _sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]: | |
| headers = {"context": message.context} | |
| session_id = message.session_id | |
| agent = self.swarm.agents.get(action.agent_name) | |
| ordered_agents = self.swarm.ordered_agents | |
| idx = next((i for i, x in enumerate(ordered_agents) if x == agent), -1) | |
| if idx == -1: | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action, | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.ERROR, | |
| headers=headers | |
| ) | |
| return | |
| # The last agent | |
| if idx == len(self.swarm.ordered_agents) - 1: | |
| receiver = None | |
| # agent loop | |
| if isinstance(agent, LoopableAgent): | |
| agent.cur_run_times += 1 | |
| if not agent.finished: | |
| receiver = agent.goto | |
| if receiver: | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=receiver, | |
| headers=headers | |
| ) | |
| else: | |
| logger.info(f"execute loop {self.swarm.cur_step}.") | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action.policy_info, | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| return | |
| # loop agent type | |
| if isinstance(agent, LoopableAgent): | |
| agent.cur_run_times += 1 | |
| if agent.finished: | |
| receiver = self.swarm.ordered_agents[idx + 1].id() | |
| else: | |
| receiver = agent.goto | |
| else: | |
| # means the loop finished | |
| receiver = self.swarm.ordered_agents[idx + 1].id() | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=receiver, | |
| headers=headers | |
| ) | |
| async def _loop_sequence_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]: | |
| headers = {"context": message.context} | |
| session_id = message.session_id | |
| agent = self.swarm.agents.get(action.agent_name) | |
| idx = next((i for i, x in enumerate(self.swarm.ordered_agents) if x == agent), -1) | |
| if idx == -1: | |
| # unknown agent, means something wrong | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action, | |
| sender=self.name(), | |
| session_id=session_id, | |
| topic=TopicType.ERROR, | |
| headers=headers | |
| ) | |
| return | |
| if idx == len(self.swarm.ordered_agents) - 1: | |
| # supported sequence loop | |
| if self.swarm.cur_step >= self.swarm.max_steps: | |
| receiver = None | |
| # agent loop | |
| if isinstance(agent, LoopableAgent): | |
| agent.cur_run_times += 1 | |
| if not agent.finished: | |
| receiver = agent.goto | |
| if receiver: | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=receiver, | |
| headers=headers | |
| ) | |
| else: | |
| # means the task finished | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action.policy_info, | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| else: | |
| self.swarm.cur_step += 1 | |
| logger.info(f"execute loop {self.swarm.cur_step}.") | |
| yield Message( | |
| category=Constants.TASK, | |
| payload='', | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.START, | |
| headers=headers | |
| ) | |
| return | |
| if isinstance(agent, LoopableAgent): | |
| agent.cur_run_times += 1 | |
| if agent.finished: | |
| receiver = self.swarm.ordered_agents[idx + 1].id() | |
| else: | |
| receiver = agent.goto | |
| else: | |
| # means the loop finished | |
| receiver = self.swarm.ordered_agents[idx + 1].id() | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.name(), | |
| session_id=session_id, | |
| receiver=receiver, | |
| headers=headers | |
| ) | |
| async def _social_stop_check(self, action: ActionModel, message: Message) -> AsyncGenerator[Message, None]: | |
| headers = {"context": message.context} | |
| agent = self.swarm.agents.get(action.agent_name) | |
| caller = message.caller | |
| session_id = message.session_id | |
| if endless_detect(self.agent_calls, | |
| endless_threshold=self.endless_threshold, | |
| root_agent_name=self.swarm.communicate_agent.id()): | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action.policy_info, | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| return | |
| if not caller or caller == self.swarm.communicate_agent.id(): | |
| if self.swarm.cur_step >= self.swarm.max_steps or self.swarm.finished: | |
| yield Message( | |
| category=Constants.TASK, | |
| payload=action.policy_info, | |
| sender=agent.id(), | |
| session_id=session_id, | |
| topic=TopicType.FINISHED, | |
| headers=headers | |
| ) | |
| else: | |
| self.swarm.cur_step += 1 | |
| logger.info(f"execute loop {self.swarm.cur_step}.") | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=self.swarm.communicate_agent.id(), | |
| headers=headers | |
| ) | |
| else: | |
| idx = 0 | |
| for idx, name in enumerate(self.agent_calls[::-1]): | |
| if name == agent.id(): | |
| break | |
| idx = len(self.agent_calls) - idx - 1 | |
| if idx: | |
| caller = self.agent_calls[idx - 1] | |
| yield Message( | |
| category=Constants.AGENT, | |
| payload=Observation(content=action.policy_info), | |
| sender=agent.id(), | |
| session_id=session_id, | |
| receiver=caller, | |
| headers=headers | |
| ) | |