|
|
|
|
|
from typing import List, Dict, Any, Callable |
|
|
|
from aworld.agents.llm_agent import Agent |
|
from aworld.config import RunConfig |
|
from aworld.core.common import Observation, ActionModel, Config |
|
|
|
|
|
class ParallelizableAgent(Agent):
    """Agent that fans one observation out to several sub-agents and runs them in parallel.

    Each sub-agent in ``agents`` gets its own task, built from the same observation and
    this agent's context. If ``aggregate_func`` is supplied, it combines the collected
    results. It is called with this agent itself (so it can read the agent's internal
    state) and the result mapping returned by the runner. Example::

        def agg(agent: ParallelizableAgent, res: Dict[str, Any]):
            ...

    NOTE(review): the default aggregation reads ``.answer`` from every value of ``res``.
    The values are therefore response objects rather than ``List[ActionModel]``. Confirm
    the exact type against ``execute_runner``.
    """

    def __init__(self,
                 conf: Config,
                 resp_parse_func: Callable[..., Any] = None,
                 agents: List[Agent] = None,
                 aggregate_func: Callable[..., Any] = None,
                 **kwargs):
        """Create the parallel agent.

        Args:
            conf: Agent configuration, forwarded to the base ``Agent``.
            resp_parse_func: Response parser, forwarded to the base ``Agent``.
            agents: Sub-agents to run in parallel. Defaults to a fresh empty list.
            aggregate_func: Optional ``(agent, res) -> Any`` callable used to combine
                the sub-agents' results. If omitted, the default aggregation in
                ``async_policy`` is used.
            **kwargs: Forwarded to the base ``Agent``.
        """
        super().__init__(conf=conf, resp_parse_func=resp_parse_func, **kwargs)
        # Build a fresh list per instance. A shared ``[]`` default would leak
        # sub-agents between every instance created without ``agents``.
        self.agents = agents if agents is not None else []
        self.aggregate_func = aggregate_func

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> List[ActionModel]:
        """Run every sub-agent on ``observation`` in parallel and collect the results.

        Args:
            observation: Input handed unchanged to each sub-agent's task.
            info: Not used by this implementation; kept for interface compatibility.
            **kwargs: Not used by this implementation.

        Returns:
            If ``aggregate_func`` is set, whatever it returns for ``(self, res)``. Here
            ``res`` is the mapping returned by ``execute_runner``. Otherwise, one
            ``ActionModel`` per entry of ``res``, with ``agent_name`` set to this agent's
            id and ``policy_info`` set to that entry's ``answer``.

        Raises:
            RuntimeError: If there are no sub-agents to run.
        """
        # Imported lazily, presumably to avoid an import cycle with the task/runner
        # modules. TODO confirm.
        from aworld.core.task import Task
        from aworld.runners.utils import choose_runners, execute_runner

        # One task per sub-agent, all sharing the same input and this agent's context.
        tasks = [Task(input=observation, agent=agent, context=self.context)
                 for agent in self.agents or []]
        if not tasks:
            raise RuntimeError("no task need to run in parallelizable agent.")

        runners = await choose_runners(tasks)
        res = await execute_runner(runners, RunConfig(reuse_process=False))

        if self.aggregate_func:
            return self.aggregate_func(self, res)

        # Default aggregation: one ActionModel per sub-task result. The mapping's keys
        # are not needed, so only the values are iterated.
        return [ActionModel(agent_name=self.id(), policy_info=response.answer)
                for response in res.values()]

    @staticmethod
    def _agent_finished(agent: Agent) -> bool:
        """Read a sub-agent's ``finished`` state, whether it is exposed as a value or a method."""
        state = agent.finished
        # A sub-agent whose ``finished`` is a plain method (e.g. a nested
        # ParallelizableAgent) would otherwise yield a bound method, which is always truthy.
        return bool(state() if callable(state) else state)

    def finished(self) -> bool:
        """Return True when every sub-agent reports that it has finished.

        This is vacuously True when there are no sub-agents.

        NOTE(review): this class defines ``finished`` as a method but reads it from
        sub-agents as an attribute. If the base ``Agent`` exposes ``finished`` as a
        property, this override shadows it with a method. Confirm which form callers
        expect before changing the interface.
        """
        return all(self._agent_finished(agent) for agent in self.agents or [])
|
|