|
|
|
|
|
import abc |
|
import json |
|
import time |
|
import traceback |
|
import uuid |
|
from collections import OrderedDict |
|
from typing import AsyncGenerator, Dict, Any, List, Union, Callable |
|
|
|
import aworld.trace as trace |
|
from aworld.config import ToolConfig |
|
from aworld.config.conf import AgentConfig, ConfigDict, ContextRuleConfig, ModelConfig, OptimizationConfig, \ |
|
LlmCompressionConfig |
|
from aworld.core.agent.agent_desc import get_agent_desc |
|
from aworld.core.agent.base import BaseAgent, AgentResult, is_agent_by_name, is_agent |
|
from aworld.core.common import Observation, ActionModel |
|
from aworld.core.context.base import AgentContext |
|
from aworld.core.context.base import Context |
|
from aworld.core.context.processor.prompt_processor import PromptProcessor |
|
from aworld.core.event import eventbus |
|
from aworld.core.event.base import Message, ToolMessage, Constants, AgentMessage |
|
from aworld.core.tool.base import ToolFactory, AsyncTool, Tool |
|
from aworld.core.memory import MemoryItem, MemoryConfig |
|
from aworld.core.tool.tool_desc import get_tool_desc |
|
from aworld.logs.util import logger, color_log, Color, trace_logger |
|
from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform |
|
from aworld.memory.main import MemoryFactory |
|
from aworld.models.llm import get_llm_model, call_llm_model, acall_llm_model, acall_llm_model_stream |
|
from aworld.models.model_response import ModelResponse, ToolCall |
|
from aworld.models.utils import tool_desc_transform, agent_desc_transform |
|
from aworld.output import Outputs |
|
from aworld.output.base import StepOutput, MessageOutput |
|
from aworld.runners.hook.hook_factory import HookFactory |
|
from aworld.runners.hook.hooks import HookPoint |
|
from aworld.utils.common import sync_exec, nest_dict_counter |
|
|
|
|
|
class Agent(BaseAgent[Observation, List[ActionModel]]): |
|
"""Basic agent for unified protocol within the framework.""" |
|
|
|
def __init__(self, |
|
conf: Union[Dict[str, Any], ConfigDict, AgentConfig], |
|
resp_parse_func: Callable[..., Any] = None, |
|
**kwargs): |
|
"""A api class implementation of agent, using the `Observation` and `List[ActionModel]` protocols. |
|
|
|
Args: |
|
conf: Agent config, supported AgentConfig, ConfigDict or dict. |
|
            resp_parse_func: Response parsing function that transforms the LLM response into the agent's standard output.
|
""" |
|
super(Agent, self).__init__(conf, **kwargs) |
|
conf = self.conf |
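        # Model name: prefer llm_config.llm_model_name, falling back to the agent-level llm_model_name.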
|
self.model_name = conf.llm_config.llm_model_name if conf.llm_config.llm_model_name else conf.llm_model_name |
|
self._llm = None |
|
self.memory = MemoryFactory.from_config(MemoryConfig(provider="inmemory")) |
|
self.system_prompt: str = kwargs.pop("system_prompt") if kwargs.get("system_prompt") else conf.system_prompt |
|
self.agent_prompt: str = kwargs.get("agent_prompt") if kwargs.get("agent_prompt") else conf.agent_prompt |
|
|
|
self.event_driven = kwargs.pop('event_driven', conf.get('event_driven', False)) |
|
self.handler: Callable[..., Any] = kwargs.get('handler') |
|
|
|
self.need_reset = kwargs.get('need_reset') if kwargs.get('need_reset') else conf.need_reset |
|
|
|
self.step_reset = kwargs.get('step_reset') if kwargs.get('step_reset') else True |
|
|
|
self.black_tool_actions: Dict[str, List[str]] = kwargs.get("black_tool_actions") if kwargs.get( |
|
"black_tool_actions") else conf.get('black_tool_actions', {}) |
|
self.resp_parse_func = resp_parse_func if resp_parse_func else self.response_parse |
|
self.history_messages = kwargs.get("history_messages") if kwargs.get("history_messages") else 100 |
|
self.use_tools_in_prompt = kwargs.get('use_tools_in_prompt', conf.use_tools_in_prompt) |
|
self.context_rule = kwargs.get("context_rule") if kwargs.get("context_rule") else conf.context_rule |
|
self.tools_instances = {} |
|
self.tools_conf = {} |
|
|
|
def reset(self, options: Dict[str, Any]): |
|
super().reset(options) |
|
self.memory = MemoryFactory.from_config( |
|
MemoryConfig(provider=options.pop("memory_store") if options.get("memory_store") else "inmemory")) |
|
|
|
def set_tools_instances(self, tools, tools_conf): |
|
self.tools_instances = tools |
|
self.tools_conf = tools_conf |
|
|
|
@property |
|
def llm(self): |
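        # Lazily create the LLM client; use the dedicated llm_config when any of its
        # connection fields are set, otherwise fall back to the agent-level config.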
|
|
|
if self._llm is None: |
|
llm_config = self.conf.llm_config or None |
|
conf = llm_config if llm_config and ( |
|
llm_config.llm_provider or llm_config.llm_base_url or llm_config.llm_api_key or llm_config.llm_model_name) else self.conf |
|
self._llm = get_llm_model(conf) |
|
return self._llm |
|
|
|
def _env_tool(self): |
|
"""Description of agent as tool.""" |
|
return tool_desc_transform(get_tool_desc(), |
|
tools=self.tool_names if self.tool_names else [], |
|
black_tool_actions=self.black_tool_actions) |
|
|
|
def _handoffs_agent_as_tool(self): |
|
"""Description of agent as tool.""" |
|
return agent_desc_transform(get_agent_desc(), |
|
agents=self.handoffs if self.handoffs else []) |
|
|
|
def _mcp_is_tool(self): |
|
"""Description of mcp servers are tools.""" |
|
try: |
|
return sync_exec(sandbox_mcp_tool_desc_transform, self.mcp_servers, self.mcp_config) |
|
except Exception as e: |
|
logger.error(f"mcp_is_tool error: {traceback.format_exc()}") |
|
return [] |
|
|
|
def desc_transform(self): |
|
"""Transform of descriptions of supported tools, agents, and MCP servers in the framework to support function calls of LLM.""" |
|
|
|
|
|
self.tools = self._env_tool() |
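        # Handoff agents and MCP servers are appended as additional tool descriptions for the LLM.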
|
|
|
self.tools.extend(self._handoffs_agent_as_tool()) |
|
|
|
self.tools.extend(self._mcp_is_tool()) |
|
|
|
self.agent_context.set_tools(self.tools) |
|
return self.tools |
|
|
|
async def async_desc_transform(self): |
|
"""Transform of descriptions of supported tools, agents, and MCP servers in the framework to support function calls of LLM.""" |
|
|
|
|
|
self.tools = self._env_tool() |
|
|
|
self.tools.extend(self._handoffs_agent_as_tool()) |
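        # With a sandbox attached, list MCP tools through it; otherwise transform the MCP server configs directly.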
|
|
|
|
|
if self.sandbox: |
|
sand_box = self.sandbox |
|
mcp_tools = await sand_box.mcpservers.list_tools() |
|
self.tools.extend(mcp_tools) |
|
else: |
|
self.tools.extend(await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config)) |
|
|
|
self.agent_context.set_tools(self.tools) |
|
|
|
def _messages_transform( |
|
self, |
|
observation: Observation, |
|
): |
|
agent_prompt = self.agent_context.agent_prompt |
|
        sys_prompt = self.agent_context.system_prompt
|
messages = [] |
|
if sys_prompt: |
|
messages.append( |
|
{'role': 'system', 'content': sys_prompt if not self.use_tools_in_prompt else sys_prompt.format( |
|
tool_list=self.tools)}) |
|
|
|
content = observation.content |
|
if agent_prompt and '{task}' in agent_prompt: |
|
content = agent_prompt.format(task=observation.content) |
|
|
|
cur_msg = {'role': 'user', 'content': content} |
|
|
|
|
|
histories = self.memory.get_last_n(self.history_messages) |
|
messages.extend(histories) |
|
|
|
action_results = observation.action_result |
|
if action_results: |
|
for action_result in action_results: |
|
cur_msg['role'] = 'tool' |
|
cur_msg['tool_call_id'] = action_result.tool_id |
|
|
|
agent_info = self.context.context_info.get(self.id()) |
|
if (self.use_tools_in_prompt and "is_use_tool_prompt" in agent_info and "tool_calls" |
|
in agent_info and agent_prompt): |
|
cur_msg['content'] = agent_prompt.format(action_list=agent_info["tool_calls"], |
|
result=content) |
|
|
|
if observation.images: |
|
urls = [{'type': 'text', 'text': content}] |
|
for image_url in observation.images: |
|
urls.append({'type': 'image_url', 'image_url': {"url": image_url}}) |
|
|
|
cur_msg['content'] = urls |
|
messages.append(cur_msg) |
|
|
|
|
|
try: |
|
messages = self._process_messages(messages=messages, agent_context=self.agent_context, context=self.context) |
|
except Exception as e: |
|
logger.warning(f"Failed to process messages in _messages_transform: {e}") |
|
logger.debug(f"Process messages error details: {traceback.format_exc()}") |
|
self.agent_context.update_messages(messages) |
|
return messages |
|
|
|
def messages_transform(self, |
|
content: str, |
|
image_urls: List[str] = None, |
|
**kwargs): |
|
"""Transform the original content to LLM messages of native format. |
|
|
|
Args: |
|
content: User content. |
|
image_urls: List of images encoded using base64. |
|
sys_prompt: Agent system prompt. |
|
max_step: The maximum list length obtained from memory. |
|
Returns: |
|
Message list for LLM. |
|
""" |
|
sys_prompt = self.agent_context.system_prompt |
|
agent_prompt = self.agent_context.agent_prompt |
|
messages = [] |
|
if sys_prompt: |
|
messages.append( |
|
{'role': 'system', 'content': sys_prompt if not self.use_tools_in_prompt else sys_prompt.format( |
|
tool_list=self.tools)}) |
|
|
|
histories = self.memory.get_last_n(self.history_messages) |
|
user_content = content |
|
if not histories and agent_prompt and '{task}' in agent_prompt: |
|
user_content = agent_prompt.format(task=content) |
|
|
|
cur_msg = {'role': 'user', 'content': user_content} |
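        # Replay prior turns from memory; if the last assistant turn issued a tool call,
        # the current message is sent as a 'tool' message carrying the matching tool_call_id.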
|
|
|
|
|
|
|
if histories: |
|
|
|
for history in histories: |
|
if not self.use_tools_in_prompt and "tool_calls" in history.metadata and history.metadata['tool_calls']: |
|
messages.append({'role': history.metadata['role'], 'content': history.content, |
|
'tool_calls': [history.metadata["tool_calls"][0]]}) |
|
else: |
|
messages.append({'role': history.metadata['role'], 'content': history.content, |
|
"tool_call_id": history.metadata.get("tool_call_id")}) |
|
|
|
if not self.use_tools_in_prompt and "tool_calls" in histories[-1].metadata and histories[-1].metadata[ |
|
'tool_calls']: |
|
tool_id = histories[-1].metadata["tool_calls"][0].id |
|
if tool_id: |
|
cur_msg['role'] = 'tool' |
|
cur_msg['tool_call_id'] = tool_id |
|
if self.use_tools_in_prompt and "is_use_tool_prompt" in histories[-1].metadata and "tool_calls" in \ |
|
histories[-1].metadata and agent_prompt: |
|
cur_msg['content'] = agent_prompt.format(action_list=histories[-1].metadata["tool_calls"], |
|
result=content) |
|
|
|
if image_urls: |
|
urls = [{'type': 'text', 'text': content}] |
|
for image_url in image_urls: |
|
urls.append({'type': 'image_url', 'image_url': {"url": image_url}}) |
|
|
|
cur_msg['content'] = urls |
|
messages.append(cur_msg) |
|
|
|
|
|
try: |
|
messages = self._process_messages(messages=messages, agent_context=self.agent_context, context=self.context) |
|
except Exception as e: |
|
logger.warning(f"Failed to process messages in messages_transform: {e}") |
|
logger.debug(f"Process messages error details: {traceback.format_exc()}") |
|
self.agent_context.set_messages(messages) |
|
return messages |
|
|
|
def use_tool_list(self, resp: ModelResponse) -> List[Dict[str, Any]]: |
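        """Parse a prompt-driven tool list from the LLM response content; expects JSON with a `use_tool_list` field."""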
|
tool_list = [] |
|
try: |
|
if resp and hasattr(resp, 'content') and resp.content: |
|
content = resp.content.strip() |
|
else: |
|
return tool_list |
|
content = content.replace('\n', '').replace('\r', '') |
|
response_json = json.loads(content) |
|
if "use_tool_list" in response_json: |
|
use_tool_list = response_json["use_tool_list"] |
|
if use_tool_list: |
|
for use_tool in use_tool_list: |
|
tool_name = use_tool["tool"] |
|
arguments = use_tool["arguments"] |
|
if tool_name and arguments: |
|
tool_list.append(use_tool) |
|
|
|
return tool_list |
|
except Exception as e: |
|
logger.debug(f"tool_parse error, content: {resp.content}, \nerror msg: {traceback.format_exc()}") |
|
return tool_list |
|
|
|
def response_parse(self, resp: ModelResponse) -> AgentResult: |
|
"""Default parse response by LLM.""" |
|
results = [] |
|
if not resp: |
|
logger.warning("LLM no valid response!") |
|
return AgentResult(actions=[], current_state=None) |
|
|
|
use_tool_list = self.use_tool_list(resp) |
|
is_call_tool = False |
|
content = '' if resp.content is None else resp.content |
|
if resp.tool_calls: |
|
is_call_tool = True |
|
for tool_call in resp.tool_calls: |
|
full_name: str = tool_call.function.name |
|
                if not full_name:
                    logger.warning("tool call response has no tool name.")
                    continue
                try:
                    params = json.loads(tool_call.function.arguments)
                except Exception:
                    logger.warning(f"failed to parse tool arguments as JSON: {tool_call.function.arguments}")
                    params = {}
|
|
|
names = full_name.split("__") |
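                # Tool call names follow the `tool__action` convention; the prefix may also be another agent's name.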
|
tool_name = names[0] |
|
if is_agent_by_name(tool_name): |
|
param_info = params.get('content', "") + ' ' + params.get('info', '') |
|
results.append(ActionModel(tool_name=tool_name, |
|
tool_id=tool_call.id, |
|
agent_name=self.id(), |
|
params=params, |
|
policy_info=content + param_info)) |
|
else: |
|
action_name = '__'.join(names[1:]) if len(names) > 1 else '' |
|
results.append(ActionModel(tool_name=tool_name, |
|
tool_id=tool_call.id, |
|
action_name=action_name, |
|
agent_name=self.id(), |
|
params=params, |
|
policy_info=content)) |
|
elif use_tool_list and len(use_tool_list) > 0: |
|
is_call_tool = True |
|
for use_tool in use_tool_list: |
|
full_name = use_tool["tool"] |
|
                if not full_name:
                    logger.warning("tool call response has no tool name.")
                    continue
                params = use_tool["arguments"]
                if not params:
                    logger.warning("tool call response has no tool arguments.")
                    continue
|
names = full_name.split("__") |
|
tool_name = names[0] |
|
if is_agent_by_name(tool_name): |
|
param_info = params.get('content', "") + ' ' + params.get('info', '') |
|
results.append(ActionModel(tool_name=tool_name, |
|
tool_id=use_tool.get('id'), |
|
agent_name=self.id(), |
|
params=params, |
|
policy_info=content + param_info)) |
|
else: |
|
action_name = '__'.join(names[1:]) if len(names) > 1 else '' |
|
results.append(ActionModel(tool_name=tool_name, |
|
tool_id=use_tool.get('id'), |
|
action_name=action_name, |
|
agent_name=self.id(), |
|
params=params, |
|
policy_info=content)) |
|
else: |
|
if content: |
|
content = content.replace("```json", "").replace("```", "") |
|
|
|
results.append(ActionModel(agent_name=self.id(), policy_info=content)) |
|
return AgentResult(actions=results, current_state=None, is_call_tool=is_call_tool) |
|
|
|
def _log_messages(self, messages: List[Dict[str, Any]]) -> None: |
|
"""Log the sequence of messages for debugging purposes""" |
|
logger.info(f"[agent] Invoking LLM with {len(messages)} messages:") |
|
for i, msg in enumerate(messages): |
|
prefix = msg.get('role') |
|
logger.info(f"[agent] Message {i + 1}: {prefix} ===================================") |
|
if isinstance(msg['content'], list): |
|
for item in msg['content']: |
|
if item.get('type') == 'text': |
|
logger.info(f"[agent] Text content: {item.get('text')}") |
|
elif item.get('type') == 'image_url': |
|
image_url = item.get('image_url', {}).get('url', '') |
|
if image_url.startswith('data:image'): |
|
logger.info(f"[agent] Image: [Base64 image data]") |
|
else: |
|
logger.info(f"[agent] Image URL: {image_url[:30]}...") |
|
else: |
|
content = str(msg['content']) |
|
chunk_size = 500 |
|
for j in range(0, len(content), chunk_size): |
|
chunk = content[j:j + chunk_size] |
|
if j == 0: |
|
logger.info(f"[agent] Content: {chunk}") |
|
else: |
|
logger.info(f"[agent] Content (continued): {chunk}") |
|
|
|
if 'tool_calls' in msg and msg['tool_calls']: |
|
for tool_call in msg.get('tool_calls'): |
|
if isinstance(tool_call, dict): |
|
logger.info(f"[agent] Tool call: {tool_call.get('name')} - ID: {tool_call.get('id')}") |
|
args = str(tool_call.get('args', {}))[:1000] |
|
logger.info(f"[agent] Tool args: {args}...") |
|
elif isinstance(tool_call, ToolCall): |
|
logger.info(f"[agent] Tool call: {tool_call.function.name} - ID: {tool_call.id}") |
|
args = str(tool_call.function.arguments)[:1000] |
|
logger.info(f"[agent] Tool args: {args}...") |
|
|
|
def _agent_result(self, actions: List[ActionModel], caller: str): |
|
if not actions: |
|
raise Exception(f'{self.id()} no action decision has been made.') |
|
|
|
tools = OrderedDict() |
|
agents = [] |
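        # Split actions into agent handoffs and per-tool action groups.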
|
for action in actions: |
|
if is_agent(action): |
|
agents.append(action) |
|
else: |
|
if action.tool_name not in tools: |
|
tools[action.tool_name] = [] |
|
tools[action.tool_name].append(action) |
|
|
|
_group_name = None |
|
|
|
if (agents and tools) or len(agents) > 1 or len(tools) > 1: |
|
_group_name = f"{self.id()}_{uuid.uuid1().hex}" |
|
|
|
|
|
if _group_name: |
|
logger.warning(f"more than one agent an tool causing confusion, will choose the first one. {agents}") |
|
agents = [agents[0]] if agents else [] |
|
for _, v in tools.items(): |
|
actions = v |
|
break |
|
|
|
if agents: |
|
return AgentMessage(payload=actions, |
|
caller=caller, |
|
sender=self.id(), |
|
receiver=actions[0].tool_name, |
|
session_id=self.context.session_id if self.context else "", |
|
headers={"context": self.context}) |
|
else: |
|
return ToolMessage(payload=actions, |
|
caller=caller, |
|
sender=self.id(), |
|
receiver=actions[0].tool_name, |
|
session_id=self.context.session_id if self.context else "", |
|
headers={"context": self.context}) |
|
|
|
def post_run(self, policy_result: List[ActionModel], policy_input: Observation) -> Message: |
|
return self._agent_result( |
|
policy_result, |
|
policy_input.from_agent_name if policy_input.from_agent_name else policy_input.observer |
|
) |
|
|
|
async def async_post_run(self, policy_result: List[ActionModel], policy_input: Observation) -> Message: |
|
return self._agent_result( |
|
policy_result, |
|
policy_input.from_agent_name if policy_input.from_agent_name else policy_input.observer |
|
) |
|
|
|
def policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]: |
|
"""The strategy of an agent can be to decide which tools to use in the environment, or to delegate tasks to other agents. |
|
|
|
Args: |
|
observation: The state observed from tools in the environment. |
|
info: Extended information is used to assist the agent to decide a policy. |
|
|
|
Returns: |
|
ActionModel sequence from agent policy |
|
""" |
|
output = None |
|
if kwargs.get("output") and isinstance(kwargs.get("output"), StepOutput): |
|
output = kwargs["output"] |
|
|
|
|
|
step = kwargs.get("step", 0) |
|
exp_id = kwargs.get("exp_id", None) |
|
source_span = trace.get_current_span() |
|
|
|
if hasattr(observation, 'context') and observation.context: |
|
self.task_histories = observation.context |
|
|
|
try: |
|
self._run_hooks_sync(self.context, HookPoint.PRE_LLM_CALL) |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
|
|
self._finished = False |
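        # Refresh tool/agent/MCP descriptions and build the LLM messages for this step.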
|
self.desc_transform() |
|
images = observation.images if self.conf.use_vision else None |
|
if self.conf.use_vision and not images and observation.image: |
|
images = [observation.image] |
|
observation.images = images |
|
messages = self.messages_transform(content=observation.content, |
|
image_urls=observation.images) |
|
|
|
self._log_messages(messages) |
|
self.memory.add(MemoryItem( |
|
content=messages[-1]['content'], |
|
metadata={ |
|
"role": messages[-1]['role'], |
|
"agent_name": self.id(), |
|
"tool_call_id": messages[-1].get("tool_call_id") |
|
} |
|
)) |
|
|
|
llm_response = None |
|
span_name = f"llm_call_{exp_id}" |
|
serializable_messages = self._to_serializable(messages) |
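        # Record the serialized prompt on the trace span for observability.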
|
with trace.span(span_name) as llm_span: |
|
llm_span.set_attributes({ |
|
"exp_id": exp_id, |
|
"step": step, |
|
"messages": json.dumps(serializable_messages, ensure_ascii=False) |
|
}) |
|
if source_span: |
|
source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False)) |
|
|
|
try: |
|
llm_response = call_llm_model( |
|
self.llm, |
|
messages=messages, |
|
model=self.model_name, |
|
temperature=self.conf.llm_config.llm_temperature, |
|
tools=self.tools if not self.use_tools_in_prompt and self.tools else None |
|
) |
|
|
|
logger.info(f"Execute response: {llm_response.message}") |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
raise e |
|
finally: |
|
if llm_response: |
|
|
|
self.update_context_usage(used_context_length=llm_response.usage['total_tokens']) |
|
|
|
self.update_llm_output(llm_response) |
|
|
|
use_tools = self.use_tool_list(llm_response) |
|
is_use_tool_prompt = len(use_tools) > 0 |
|
if llm_response.error: |
|
logger.info(f"llm result error: {llm_response.error}") |
|
else: |
|
info = { |
|
"role": "assistant", |
|
"agent_name": self.id(), |
|
"tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools, |
|
"is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False |
|
} |
|
self.memory.add(MemoryItem( |
|
content=llm_response.content, |
|
metadata=info |
|
)) |
|
|
|
self.context.context_info[self.id()] = info |
|
else: |
|
logger.error(f"{self.id()} failed to get LLM response") |
|
raise RuntimeError(f"{self.id()} failed to get LLM response") |
|
|
|
try: |
|
self._run_hooks_sync(self.context, HookPoint.POST_LLM_CALL) |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
|
|
agent_result = sync_exec(self.resp_parse_func, llm_response) |
|
if not agent_result.is_call_tool: |
|
self._finished = True |
|
|
|
if output: |
|
output.add_part(MessageOutput(source=llm_response, json_parse=False)) |
|
output.mark_finished() |
|
return agent_result.actions |
|
|
|
async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]: |
|
"""The strategy of an agent can be to decide which tools to use in the environment, or to delegate tasks to other agents. |
|
|
|
Args: |
|
observation: The state observed from tools in the environment. |
|
info: Extended information is used to assist the agent to decide a policy. |
|
|
|
Returns: |
|
ActionModel sequence from agent policy |
|
""" |
|
outputs = None |
|
if kwargs.get("outputs") and isinstance(kwargs.get("outputs"), Outputs): |
|
outputs = kwargs.get("outputs") |
|
|
|
|
|
source_span = trace.get_current_span() |
|
|
|
if hasattr(observation, 'context') and observation.context: |
|
self.task_histories = observation.context |
|
|
|
try: |
|
events = [] |
|
async for event in self.run_hooks(self.context, HookPoint.PRE_LLM_CALL): |
|
events.append(event) |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
|
|
self._finished = False |
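        # Build the LLM input (descriptions + messages) and delegate the call to _call_llm_model.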
|
messages = await self._prepare_llm_input(observation, info, **kwargs) |
|
|
|
serializable_messages = self._to_serializable(messages) |
|
llm_response = None |
|
if source_span: |
|
source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False)) |
|
try: |
|
llm_response = await self._call_llm_model(observation, messages, info, **kwargs) |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
raise e |
|
finally: |
|
if llm_response: |
|
|
|
self.update_context_usage(used_context_length=llm_response.usage['total_tokens']) |
|
|
|
self.update_llm_output(llm_response) |
|
|
|
use_tools = self.use_tool_list(llm_response) |
|
is_use_tool_prompt = len(use_tools) > 0 |
|
if llm_response.error: |
|
logger.info(f"llm result error: {llm_response.error}") |
|
else: |
|
self.memory.add(MemoryItem( |
|
content=llm_response.content, |
|
metadata={ |
|
"role": "assistant", |
|
"agent_name": self.id(), |
|
"tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools, |
|
"is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False |
|
} |
|
)) |
|
else: |
|
logger.error(f"{self.id()} failed to get LLM response") |
|
raise RuntimeError(f"{self.id()} failed to get LLM response") |
|
|
|
try: |
|
events = [] |
|
async for event in self.run_hooks(self.context, HookPoint.POST_LLM_CALL): |
|
events.append(event) |
|
except Exception as e: |
|
logger.warn(traceback.format_exc()) |
|
|
|
agent_result = sync_exec(self.resp_parse_func, llm_response) |
|
if not agent_result.is_call_tool: |
|
self._finished = True |
|
return agent_result.actions |
|
|
|
def _to_serializable(self, obj): |
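        """Recursively convert messages and model objects into JSON-serializable structures."""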
|
if isinstance(obj, dict): |
|
return {k: self._to_serializable(v) for k, v in obj.items()} |
|
elif isinstance(obj, list): |
|
return [self._to_serializable(i) for i in obj] |
|
elif hasattr(obj, "to_dict"): |
|
return obj.to_dict() |
|
elif hasattr(obj, "model_dump"): |
|
return obj.model_dump() |
|
elif hasattr(obj, "dict"): |
|
return obj.dict() |
|
else: |
|
return obj |
|
|
|
async def llm_and_tool_execution(self, observation: Observation, messages: List[Dict[str, str]] = [], |
|
info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]: |
|
"""Perform combined LLM call and tool execution operations. |
|
|
|
Args: |
|
observation: The state observed from the environment |
|
            messages: Optional pre-built LLM messages; if empty, they are built from the observation
            info: Extended information to assist the agent in decision-making
|
**kwargs: Other parameters |
|
|
|
Returns: |
|
ActionModel sequence. If a tool is executed, includes the tool execution result. |
|
""" |
|
|
|
llm_response = await self._call_llm_model(observation, messages, info, **kwargs) |
|
if llm_response: |
|
use_tools = self.use_tool_list(llm_response) |
|
is_use_tool_prompt = len(use_tools) > 0 |
|
if llm_response.error: |
|
logger.info(f"llm result error: {llm_response.error}") |
|
else: |
|
self.memory.add(MemoryItem( |
|
content=llm_response.content, |
|
metadata={ |
|
"role": "assistant", |
|
"agent_name": self.id(), |
|
"tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools, |
|
"is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False |
|
} |
|
)) |
|
else: |
|
logger.error(f"{self.id()} failed to get LLM response") |
|
raise RuntimeError(f"{self.id()} failed to get LLM response") |
|
|
|
agent_result = sync_exec(self.resp_parse_func, llm_response) |
|
if not agent_result.is_call_tool: |
|
self._finished = True |
|
return agent_result.actions |
|
else: |
|
result = await self._execute_tool(agent_result.actions) |
|
return result |
|
|
|
async def _prepare_llm_input(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs): |
|
"""Prepare LLM input |
|
Args: |
|
observation: The state observed from the environment |
|
info: Extended information to assist the agent in decision-making |
|
**kwargs: Other parameters |
|
""" |
|
await self.async_desc_transform() |
|
images = observation.images if self.conf.use_vision else None |
|
if self.conf.use_vision and not images and observation.image: |
|
images = [observation.image] |
|
messages = self.messages_transform(content=observation.content, |
|
image_urls=images) |
|
|
|
self._log_messages(messages) |
|
self.memory.add(MemoryItem( |
|
content=messages[-1]['content'], |
|
metadata={ |
|
"role": messages[-1]['role'], |
|
"agent_name": self.id(), |
|
"tool_call_id": messages[-1].get("tool_call_id") |
|
} |
|
)) |
|
|
|
return messages |
|
|
|
    def _process_messages(self, messages: List[Dict[str, Any]], agent_context: AgentContext = None,
                          context: Context = None) -> List[Dict[str, Any]]:
|
origin_messages = messages |
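        # Apply the context rule (truncation/compression) via PromptProcessor and record before/after sizes on a trace span.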
|
st = time.time() |
|
with trace.span(f"llm_context_process", attributes={ |
|
"start_time": st |
|
}) as compress_span: |
|
if agent_context.context_rule is None: |
|
logger.debug('debug|skip process_messages context_rule is None') |
|
return messages |
|
origin_len = compressed_len = len(str(messages)) |
|
origin_messages_count = truncated_messages_count = len(messages) |
|
try: |
|
prompt_processor = PromptProcessor(agent_context) |
|
result = prompt_processor.process_messages(messages, context) |
|
messages = result.processed_messages |
|
|
|
compressed_len = len(str(messages)) |
|
truncated_messages_count = len(messages) |
|
logger.debug( |
|
f'debug|llm_context_process|{origin_len}|{compressed_len}|{origin_messages_count}|{truncated_messages_count}|\n|{origin_messages}\n|{messages}') |
|
return messages |
|
finally: |
|
compress_span.set_attributes({ |
|
"end_time": time.time(), |
|
"duration": time.time() - st, |
|
|
|
"origin_messages_count": origin_messages_count, |
|
"truncated_messages_count": truncated_messages_count, |
|
"truncated_ratio": round(truncated_messages_count / origin_messages_count, 2), |
|
|
|
"origin_len": origin_len, |
|
"compressed_len": compressed_len, |
|
"compress_ratio": round(compressed_len / origin_len, 2) |
|
}) |
|
|
|
async def _call_llm_model(self, observation: Observation, messages: List[Dict[str, str]] = [], |
|
info: Dict[str, Any] = {}, **kwargs) -> ModelResponse: |
|
"""Perform LLM call |
|
Args: |
|
            observation: The state observed from the environment
            messages: Optional pre-built LLM messages; if empty, they are built from the observation
            info: Extended information to assist the agent in decision-making
            **kwargs: Other parameters
|
Returns: |
|
LLM response |
|
""" |
|
outputs = None |
|
if kwargs.get("outputs") and isinstance(kwargs.get("outputs"), Outputs): |
|
outputs = kwargs.get("outputs") |
|
if not messages: |
|
            messages = await self._prepare_llm_input(observation, info, **kwargs)
|
|
|
llm_response = None |
|
source_span = trace.get_current_span() |
|
serializable_messages = self._to_serializable(messages) |
|
|
|
if source_span: |
|
source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False)) |
|
|
|
try: |
|
stream_mode = kwargs.get("stream", False) |
|
if stream_mode: |
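                # Streaming mode: accumulate chunks into a single ModelResponse while forwarding the text stream as a MessageOutput.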
|
llm_response = ModelResponse(id="", model="", content="", tool_calls=[]) |
|
resp_stream = acall_llm_model_stream( |
|
self.llm, |
|
messages=messages, |
|
model=self.model_name, |
|
temperature=self.conf.llm_config.llm_temperature, |
|
tools=self.tools if not self.use_tools_in_prompt and self.tools else None, |
|
stream=True |
|
) |
|
|
|
async def async_call_llm(resp_stream, json_parse=False): |
|
llm_resp = ModelResponse(id="", model="", content="", tool_calls=[]) |
|
|
|
|
|
async def async_generator(): |
|
async for chunk in resp_stream: |
|
if chunk.content: |
|
llm_resp.content += chunk.content |
|
yield chunk.content |
|
if chunk.tool_calls: |
|
llm_resp.tool_calls.extend(chunk.tool_calls) |
|
if chunk.error: |
|
llm_resp.error = chunk.error |
|
llm_resp.id = chunk.id |
|
llm_resp.model = chunk.model |
|
llm_resp.usage = nest_dict_counter(llm_resp.usage, chunk.usage) |
|
|
|
return MessageOutput(source=async_generator(), json_parse=json_parse), llm_resp |
|
|
|
output, response = await async_call_llm(resp_stream) |
|
llm_response = response |
|
|
|
if eventbus is not None and resp_stream: |
|
output_message = Message( |
|
category=Constants.OUTPUT, |
|
payload=output, |
|
sender=self.id(), |
|
session_id=self.context.session_id if self.context else "", |
|
headers={"context": self.context} |
|
) |
|
await eventbus.publish(output_message) |
|
elif not self.event_driven and outputs: |
|
outputs.add_output(output) |
|
|
|
else: |
|
llm_response = await acall_llm_model( |
|
self.llm, |
|
messages=messages, |
|
model=self.model_name, |
|
temperature=self.conf.llm_config.llm_temperature, |
|
tools=self.tools if not self.use_tools_in_prompt and self.tools else None, |
|
stream=kwargs.get("stream", False) |
|
) |
|
                if eventbus is None:
                    logger.warning("eventbus is None, LLM response will not be published as an output event.")
|
if eventbus is not None and llm_response: |
|
await eventbus.publish(Message( |
|
category=Constants.OUTPUT, |
|
payload=llm_response, |
|
sender=self.id(), |
|
session_id=self.context.session_id if self.context else "", |
|
headers={"context": self.context} |
|
)) |
|
elif not self.event_driven and outputs: |
|
outputs.add_output(MessageOutput(source=llm_response, json_parse=False)) |
|
|
|
logger.info(f"Execute response: {json.dumps(llm_response.to_dict(), ensure_ascii=False)}") |
|
|
|
|
|
        except Exception as e:
            logger.warning(traceback.format_exc())
            raise e
        return llm_response
|
|
|
async def _execute_tool(self, actions: List[ActionModel]) -> Any: |
|
"""Execute tool calls |
|
|
|
Args: |
|
            actions: The actions to execute
|
|
|
Returns: |
|
The result of tool execution |
|
""" |
|
tool_actions = [] |
|
for act in actions: |
|
if is_agent(act): |
|
continue |
|
else: |
|
tool_actions.append(act) |
|
|
|
msg = None |
|
terminated = False |
|
|
|
tool_mapping = dict() |
|
reward = 0.0 |
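        # Lazily create and reset tool instances, then group actions by tool name.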
|
|
|
for act in tool_actions: |
|
            if act.tool_name not in self.tools_instances:
|
|
|
conf = self.tools_conf.get(act.tool_name) |
|
if not conf: |
|
conf = ToolConfig(exit_on_failure=self.task.conf.get('exit_on_failure')) |
|
tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False) |
|
if isinstance(tool, Tool): |
|
tool.reset() |
|
elif isinstance(tool, AsyncTool): |
|
await tool.reset() |
|
tool_mapping[act.tool_name] = [] |
|
self.tools_instances[act.tool_name] = tool |
|
if act.tool_name not in tool_mapping: |
|
tool_mapping[act.tool_name] = [] |
|
tool_mapping[act.tool_name].append(act) |
|
|
|
observation = None |
|
|
|
for tool_name, action in tool_mapping.items(): |
|
|
|
if isinstance(self.tools_instances[tool_name], Tool): |
|
message = self.tools_instances[tool_name].step(action) |
|
elif isinstance(self.tools_instances[tool_name], AsyncTool): |
|
|
|
message = await self.tools_instances[tool_name].step(action, agent=self) |
|
else: |
|
logger.warning(f"Unsupported tool type: {self.tools_instances[tool_name]}") |
|
continue |
|
|
|
observation, reward, terminated, _, info = message.payload |
|
|
|
|
|
if info.get("exception"): |
|
color_log(f"Agent {self.id()} _execute_tool failed with exception: {info['exception']}", |
|
color=Color.red) |
|
msg = f"Agent {self.id()} _execute_tool failed with exception: {info['exception']}" |
|
logger.info(f"Agent {self.id()} _execute_tool finished by tool action: {action}.") |
|
log_ob = Observation(content='' if observation.content is None else observation.content, |
|
action_result=observation.action_result) |
|
trace_logger.info(f"{tool_name} observation: {log_ob}", color=Color.green) |
|
self.memory.add(MemoryItem( |
|
content=observation.content, |
|
metadata={ |
|
"role": "tool", |
|
"agent_name": self.id(), |
|
"tool_call_id": action[0].tool_id |
|
} |
|
)) |
|
return [ActionModel(agent_name=self.id(), policy_info=observation.content)] |
|
|
|
def _init_context(self, context: Context): |
|
super()._init_context(context) |
|
|
|
llm_config = self.conf.llm_config |
|
context_rule = self.context_rule |
|
if context_rule is None: |
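            # Default rule: enable optimization with the full token budget and disable LLM compression.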
|
context_rule = ContextRuleConfig( |
|
optimization_config=OptimizationConfig( |
|
enabled=True, |
|
max_token_budget_ratio=1.0 |
|
), |
|
llm_compression_config=LlmCompressionConfig( |
|
enabled=False |
|
) |
|
) |
|
self.agent_context.set_model_config(llm_config) |
|
self.agent_context.context_rule = context_rule |
|
self.agent_context.system_prompt = self.system_prompt |
|
self.agent_context.agent_prompt = self.agent_prompt |
|
logger.debug(f'init_context llm_agent {self.name()} {self.agent_context} {self.conf} {self.context_rule}') |
|
|
|
def update_system_prompt(self, system_prompt: str): |
|
self.system_prompt = system_prompt |
|
self.agent_context.system_prompt = system_prompt |
|
logger.info(f"Agent {self.name()} system_prompt updated") |
|
|
|
def update_agent_prompt(self, agent_prompt: str): |
|
self.agent_prompt = agent_prompt |
|
self.agent_context.agent_prompt = agent_prompt |
|
logger.info(f"Agent {self.name()} agent_prompt updated") |
|
|
|
def update_context_rule(self, context_rule: ContextRuleConfig): |
|
self.agent_context.context_rule = context_rule |
|
logger.info(f"Agent {self.name()} context_rule updated") |
|
|
|
def update_context_usage(self, used_context_length: int = None, total_context_length: int = None): |
|
self.agent_context.update_context_usage(used_context_length, total_context_length) |
|
logger.debug(f"Agent {self.name()} context usage updated: {self.agent_context.context_usage}") |
|
|
|
def update_llm_output(self, llm_response: ModelResponse): |
|
self.agent_context.set_llm_output(llm_response) |
|
logger.debug(f"Agent {self.name()} llm output updated: {self.agent_context.llm_output}") |
|
|
|
async def run_hooks(self, context: Context, hook_point: str): |
|
"""Execute hooks asynchronously""" |
|
from aworld.runners.hook.hook_factory import HookFactory |
|
from aworld.core.event.base import Message |
|
|
|
|
|
all_hooks = HookFactory.hooks(hook_point) |
|
hooks = all_hooks.get(hook_point, []) |
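        # Each hook receives an "agent_hook" Message (payload None) and may yield a result message.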
|
|
|
for hook in hooks: |
|
try: |
|
|
|
message = Message( |
|
category="agent_hook", |
|
payload=None, |
|
sender=self.id(), |
|
session_id=context.session_id if hasattr(context, 'session_id') else None, |
|
headers={"context": self.context} |
|
) |
|
|
|
|
|
msg = await hook.exec(message, context) |
|
if msg: |
|
logger.debug(f"Hook {hook.point()} executed successfully") |
|
yield msg |
|
except Exception as e: |
|
logger.warning(f"Hook {hook.point()} execution failed: {traceback.format_exc()}") |
|
|
|
def _run_hooks_sync(self, context: Context, hook_point: str): |
|
"""Execute hooks synchronously""" |
|
|
|
try: |
|
sync_exec(self.run_hooks, context, hook_point) |
|
except Exception as e: |
|
logger.warn(f"Failed to execute hooks for {hook_point}: {traceback.format_exc()}") |
|
|