Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| # Copyright (c) 2025 inclusionAI. | |
| from typing import List, Dict, Any, Callable | |
| from aworld.agents.llm_agent import Agent | |
| from aworld.config import RunConfig | |
| from aworld.core.common import Observation, ActionModel, Config | |
class ParallelizableAgent(Agent):
    """Agent that fans one observation out to several sub-agents in parallel.

    Each sub-agent in ``agents`` is wrapped in its own ``Task`` that receives
    the same input observation and this agent's ``context``. The tasks are
    executed through the runner utilities, and the results are either handed
    to ``aggregate_func`` or, by default, turned into one ``ActionModel`` per
    result.

    Extension functions receive the agent itself as their first argument, so
    they can read the agent's internal state. Example ``aggregate_func``::

        def agg(agent: ParallelizableAgent, res: Dict[str, Any]) -> List[ActionModel]:
            ...

    ``res`` is whatever mapping ``execute_runner`` returns; the default
    (non-aggregated) path reads ``.answer`` from each of its values.
    """

    def __init__(self,
                 conf: Config,
                 resp_parse_func: Callable[..., Any] = None,
                 agents: List[Agent] = None,
                 aggregate_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the parallel agent.

        Args:
            conf: Agent configuration, forwarded to ``Agent.__init__``.
            resp_parse_func: Optional response parser, forwarded to
                ``Agent.__init__``.
            agents: Sub-agents to run in parallel. When omitted (or ``None``),
                each instance gets its own new empty list.
            aggregate_func: Optional ``(agent, res)`` callable that combines
                the parallel results; its return value is returned from
                ``async_policy`` as-is.
            **kwargs: Forwarded to ``Agent.__init__``.
        """
        super().__init__(conf=conf, resp_parse_func=resp_parse_func, **kwargs)
        # Fresh list per instance: a literal `[]` default would be a single list
        # object shared by every instance created without `agents`.
        self.agents = agents if agents is not None else []
        # The function of aggregating the results of the parallel execution of agents.
        self.aggregate_func = aggregate_func

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None,
                           **kwargs) -> List[ActionModel]:
        """Run every sub-agent on ``observation`` in parallel and combine the results.

        Args:
            observation: Input given to every sub-agent's task.
            info: Not used by this implementation; accepted for signature
                compatibility with ``Agent.async_policy``.
            **kwargs: Not used by this implementation.

        Returns:
            ``aggregate_func(self, res)`` when an aggregate function is set;
            otherwise one ``ActionModel`` per runner result, named after this
            agent and carrying that result's ``answer`` as ``policy_info``.

        Raises:
            RuntimeError: If there are no sub-agents to run.
        """
        # Imported lazily, presumably to avoid an import cycle with the
        # task/runner modules — confirm before hoisting to module level.
        from aworld.core.task import Task
        from aworld.runners.utils import choose_runners, execute_runner

        tasks = [
            Task(input=observation, agent=agent, context=self.context)
            for agent in (self.agents or [])
        ]
        if not tasks:
            raise RuntimeError("no task need to run in parallelizable agent.")

        runners = await choose_runners(tasks)
        res = await execute_runner(runners, RunConfig(reuse_process=False))

        if self.aggregate_func:
            return self.aggregate_func(self, res)

        # Only the results are needed; the mapping keys are ignored.
        agent_name = self.id()
        return [
            ActionModel(agent_name=agent_name, policy_info=result.answer)
            for result in res.values()
        ]

    def finished(self) -> bool:
        """Return True when every sub-agent reports that it has finished.

        Vacuously True when there are no sub-agents.
        """

        def _is_done(agent) -> bool:
            # A sub-agent may expose `finished` as a plain value/property or, like
            # this class, as a method. A bound method object is always truthy, so
            # it must be called rather than tested directly.
            flag = agent.finished
            return bool(flag()) if callable(flag) else bool(flag)

        # NOTE(review): if the base ``Agent`` exposes ``finished`` as a property,
        # this override turns it into a method on this class — confirm that
        # callers invoke ``finished()`` rather than reading it as an attribute.
        return all(_is_done(agent) for agent in (self.agents or []))