File size: 10,039 Bytes
5fc6c27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import copy
import json
import traceback
from typing import Dict, Any, List, Union

from examples.tools.common import Agents
from aworld.core.agent.base import AgentResult
from aworld.agents.llm_agent import Agent
from aworld.models.llm import call_llm_model
from aworld.config.conf import AgentConfig, ConfigDict
from aworld.core.common import Observation, ActionModel
from aworld.logs.util import logger
from examples.plan_execute.prompts import *
from examples.plan_execute.utils import extract_pattern


class ExecuteAgent(Agent):
    """Agent that carries out instructions, optionally by calling tools.

    Each turn it replays its own trajectory as chat history, appends the new
    observation (as ``tool`` messages answering the previous turn's tool calls,
    or otherwise as a ``user`` message), asks the LLM, and converts the reply
    into ``ActionModel`` objects.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
        super(ExecuteAgent, self).__init__(conf, **kwargs)

    def id(self) -> str:
        return Agents.EXECUTE.value

    def reset(self, options: Dict[str, Any]):
        """Reset the agent and rebuild the system prompt from the query task.

        NOTE(review): assumes ``super().reset(options)`` populates ``self.task``.
        """
        super().reset(options)

        self.system_prompt = execute_system_prompt.format(task=self.task)
        self.step_reset = False

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> Union[
        List[ActionModel], None]:
        # ``None`` default (instead of a shared ``{}``): ``info`` is stored in the
        # trajectory, so a shared mutable default would alias across calls.
        await self.async_desc_transform()
        return self._common(observation, info)

    def policy(self,
               observation: Observation,
               info: Dict[str, Any] = None,
               **kwargs) -> List[ActionModel] | None:
        self.desc_transform()
        return self._common(observation, info)

    def _history_messages(self) -> List[Dict[str, Any]]:
        """Rebuild chat messages from the trajectory (what was sent + what the LLM replied)."""
        messages = []
        for step in self.trajectory:
            sent = step[0].content
            reply = step[-1]
            # Recorded observations hold the list of messages sent on that turn.
            if isinstance(sent, list):
                messages.extend(sent)
            else:
                messages.append(sent)

            if reply.tool_calls:
                messages.append({'role': 'assistant', 'content': '', 'tool_calls': reply.tool_calls})
            else:
                messages.append({'role': 'assistant', 'content': reply.content})
        return messages

    def _pending_tool_messages(self, messages: List[Dict[str, Any]], content) -> List[Dict[str, Any]]:
        """Build ``tool`` messages (all carrying ``content``) for tool calls not yet answered in ``messages``."""
        answered = {
            msg.get("tool_call_id") for msg in messages
            if msg.get("role") == "tool" and msg.get("tool_call_id")
        }
        pending = []
        for step in self.trajectory:
            for tool_call in step[-1].tool_calls or []:
                if tool_call.id not in answered:
                    pending.append({
                        "role": "tool",
                        "content": content,
                        "tool_call_id": tool_call.id
                    })
        return pending

    @staticmethod
    def _check_tool_pairing(messages: List[Dict[str, Any]]) -> None:
        """Raise ValueError unless every assistant tool call has exactly-matching tool responses."""
        call_ids = set()
        response_ids = set()
        for msg in messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                call_ids.update(call.id for call in msg["tool_calls"])
            elif msg.get("role") == "tool":
                response_ids.add(msg.get("tool_call_id"))

        if call_ids != response_ids:
            missing_calls = call_ids - response_ids
            extra_responses = response_ids - call_ids
            error_msg = f"Tool calls and responses mismatch. Missing responses for tool_calls: {missing_calls}, Extra responses: {extra_responses}"
            logger.error(error_msg)
            raise ValueError(error_msg)

    @staticmethod
    def _tool_actions(tool_calls) -> List[ActionModel]:
        """Convert LLM tool calls named ``<tool>__<action>`` into ActionModels.

        Calls without a function name are skipped; empty/missing arguments
        become ``{}``.
        """
        actions = []
        for tool_call in tool_calls or []:
            tool_action_name: str = tool_call.function.name
            if not tool_action_name:
                continue

            # Split on the first "__" only: the action part may itself contain "__".
            tool_name, _, action_name = tool_action_name.partition("__")
            arguments = tool_call.function.arguments
            params = json.loads(arguments) if arguments else {}
            actions.append(ActionModel(agent_name=Agents.EXECUTE.value,
                                       tool_name=tool_name,
                                       action_name=action_name,
                                       params=params))
        return actions

    def _common(self, observation, info):
        """Run one LLM turn and translate the reply into actions.

        Args:
            observation: Current observation. Its ``content`` is the incoming
                instruction or tool output; when it is ``None`` the error of
                the first action result (if any) is used instead.
            info: Extra info stored unchanged in the trajectory.

        Returns:
            A list of ``ActionModel``: one per tool call requested by the LLM
            (the first one carrying the reply text), or a single action holding
            the reply text. ``self._finished`` becomes True only when a
            ``final_answer`` pattern is extracted.

        Raises:
            ValueError: if the assembled history has unpaired tool calls/responses.
        """
        self._finished = False
        content = observation.content
        if content is None and observation.action_result:
            content = observation.action_result[0].error

        ## build input of llm
        input_content = [
            {'role': 'system', 'content': self.system_prompt},
        ]
        input_content.extend(self._history_messages())

        new_messages = self._pending_tool_messages(input_content, content) if self.trajectory else []
        if not new_messages:
            # Nothing awaits a tool response, so the observation is a new
            # instruction. Keep it in ``new_messages`` so it is recorded in the
            # trajectory and replayed as history on later turns.
            new_messages = [{"role": "user", "content": content}]
        input_content.extend(new_messages)
        self._check_tool_pairing(input_content)

        llm_result = None
        tool_calls = []
        try:
            llm_result = call_llm_model(self.llm, input_content, model=self.model_name,
                                        tools=self.tools, temperature=0)
            logger.info(f"Execute response: {llm_result.message}")
            res = self.response_parse(llm_result)
            content = res.actions[0].policy_info
            tool_calls = llm_result.tool_calls
        except Exception:
            # Best effort: log and continue. NOTE(review): on failure ``content``
            # still holds the incoming observation and is returned below.
            logger.warning(traceback.format_exc())
        finally:
            if llm_result:
                ob = copy.deepcopy(observation)
                ob.content = new_messages
                self.trajectory.append((ob, info, llm_result))
            else:
                logger.warning("no result to record!")

        res = self._tool_actions(tool_calls)
        if res:
            res[0].policy_info = content
            self._finished = False
        elif content:
            policy_info = extract_pattern(content, "final_answer")
            if policy_info:
                res.append(ActionModel(agent_name=Agents.EXECUTE.value,
                                       policy_info=policy_info))
                self._finished = True
            else:
                res.append(ActionModel(agent_name=Agents.EXECUTE.value,
                                       policy_info=content))

        logger.info(f">>> execute result: {res}")

        result = AgentResult(actions=res,
                             current_state=None)
        return result.actions


class PlanAgent(Agent):
    """Planning agent: asks the LLM for the next step and addresses it to the execute agent.

    The returned action has ``tool_name=Agents.EXECUTE.value`` and carries the
    LLM reply (plus a follow-up prompt) as ``policy_info``.
    """

    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
        super(PlanAgent, self).__init__(conf, **kwargs)

    def id(self) -> str:
        return Agents.PLAN.value

    def reset(self, options: Dict[str, Any]):
        """Reset the plan agent and rebuild its prompts from the query task.

        NOTE(review): assumes ``super().reset(options)`` populates ``self.task``.
        """
        super().reset(options)

        self.system_prompt = plan_system_prompt.format(task=self.task)
        self.done_prompt = plan_done_prompt.format(task=self.task)
        self.postfix_prompt = plan_postfix_prompt.format(task=self.task)
        # Sent instead of the observation on the very first turn.
        self.first_prompt = init_prompt
        self.first = True
        self.step_reset = False

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = None, **kwargs) -> Union[
        List[ActionModel], None]:
        # Reset like the sync ``policy`` does, so a stale ``True`` cannot leak
        # into this turn on the async path.
        self._finished = False
        await self.async_desc_transform()
        return self._common(observation, info)

    def policy(self,
               observation: Observation,
               info: Dict[str, Any] = None,
               **kwargs) -> List[ActionModel] | None:
        self._finished = False
        self.desc_transform()
        return self._common(observation, info)

    def _common(self, observation, info):
        """Ask the LLM for the next planning step and wrap it in an ActionModel.

        Args:
            observation: Current observation; its ``content`` is the user message
                (replaced by ``first_prompt`` on the first call).
            info: Extra info stored unchanged in the trajectory.

        Returns:
            A single-element list with an ActionModel addressed to the execute
            agent. ``self._finished`` becomes True when the reply contains
            ``TASK_DONE`` on any turn except the first.

        Raises:
            Any exception raised by ``call_llm_model`` is logged and re-raised.
        """
        llm_result = None
        input_content = [
            {'role': 'system', 'content': self.system_prompt},
        ]
        # build input of llm based history
        for traj in self.trajectory:
            input_content.append({'role': 'user', 'content': traj[0].content})
            # plan agent no tool to call, use content
            input_content.append({'role': 'assistant', 'content': traj[-1].content})

        message = observation.content
        if self.first_prompt:
            message = self.first_prompt
            self.first_prompt = None

        input_content.append({"role": "user", "content": message})
        try:
            llm_result = call_llm_model(self.llm, messages=input_content, model=self.model_name)
            logger.info(f"Plan response: {llm_result.message}")
        except Exception:
            logger.warning(traceback.format_exc())
            raise
        finally:
            if llm_result:
                ob = copy.deepcopy(observation)
                ob.content = message
                self.trajectory.append((ob, info, llm_result))
            else:
                logger.warning("no result to record!")
        res = self.response_parse(llm_result)
        # An empty reply would otherwise make the "TASK_DONE" check raise TypeError.
        content = res.actions[0].policy_info or ""
        if "TASK_DONE" not in content:
            content += self.done_prompt
        else:
            # The task is done, and the assistant agent need to give the final answer about the original task
            content += self.postfix_prompt
            if not self.first:
                self._finished = True

        self.first = False
        logger.info(f">>> plan result: {content}")
        result = AgentResult(actions=[ActionModel(agent_name=Agents.PLAN.value,
                                                  tool_name=Agents.EXECUTE.value,
                                                  policy_info=content)],
                             current_state=None)
        return result.actions