Upload 6 files
- examples/android/README.md +3 -0
- examples/android/agent.py +252 -0
- examples/android/prompts.py +58 -0
- examples/android/requirements.txt +6 -0
- examples/android/run.py +34 -0
- examples/android/utils.py +277 -0
examples/android/README.md
ADDED
@@ -0,0 +1,3 @@
+# Android Agents
+
+Agents specialized in Android device automation.
examples/android/agent.py
ADDED
@@ -0,0 +1,252 @@
+# coding: utf-8
+# Copyright (c) 2025 inclusionAI.
+
+import json
+import time
+import traceback
+from typing import Dict, Any, Optional, List, Union
+
+from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage
+
+from examples.android.prompts import SYSTEM_PROMPT, LAST_STEP_PROMPT
+from examples.android.utils import (
+    AgentState,
+    AgentHistory,
+    AgentHistoryList,
+    ActionResult,
+    PolicyMetadata,
+    AgentBrain,
+    Trajectory
+)
+from examples.browsers.common import AgentStepInfo
+from aworld.config.conf import AgentConfig, ConfigDict
+from aworld.core.agent.base import AgentResult
+from aworld.agents.llm_agent import Agent
+from aworld.core.common import Observation, ActionModel, ToolActionInfo
+from aworld.logs.util import logger
+from examples.tools.tool_action import AndroidAction
+
+
+class AndroidAgent(Agent):
+    def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
+        super(AndroidAgent, self).__init__(conf, **kwargs)
+        provider = self.conf.llm_config.llm_provider if self.conf.llm_config.llm_provider else self.conf.llm_provider
+        if self.conf.llm_config.llm_provider:
+            self.conf.llm_config.llm_provider = "chat" + provider
+        else:
+            self.conf.llm_provider = "chat" + provider
+        self.available_actions_desc = self._build_action_prompt()
+        # Settings
+        self.settings = self.conf
+
+    def reset(self, options: Dict[str, Any]):
+        super(AndroidAgent, self).__init__(options)
+        # State
+        self.state = AgentState()
+        # History
+        self.history = AgentHistoryList(history=[])
+        self.trajectory = Trajectory(history=[])
+
+    def _build_action_prompt(self) -> str:
+        def _prompt(info: ToolActionInfo) -> str:
+            s = f'{info.desc}:\n'
+            s += '{' + str(info.name) + ': '
+            if info.input_params:
+                s += str({k: {"title": k, "type": v} for k, v in info.input_params.items()})
+            s += '}'
+            return s
+
+        # Iterate over all android actions
+        val = "\n".join([_prompt(v.value) for k, v in AndroidAction.__members__.items()])
+        return val
+
+    def policy(self,
+               observation: Observation,
+               info: Dict[str, Any] = None,
+               **kwargs) -> Union[List[ActionModel], None]:
+        self._finished = False
+        step_info = AgentStepInfo(number=self.state.n_steps, max_steps=self.conf.max_steps)
+        last_step_msg = None
+        if step_info and step_info.is_last_step():
+            # Add last step warning if needed
+            last_step_msg = HumanMessage(
+                content=LAST_STEP_PROMPT)
+            logger.info('Last step finishing up')
+
+        logger.info(f'[agent] 📍 Step {self.state.n_steps}')
+        step_start_time = time.time()
+
+        try:
+
+            xml_content, base64_img = observation.dom_tree, observation.image
+
+            if xml_content is None:
+                logger.error("[agent] ⚠ Failed to get UI state, stopping task")
+                self.stop()
+                return None
+
+            self.state.last_result = (xml_content, base64_img if base64_img else "")
+
+            logger.info("[agent] 🤖 Analyzing current state with LLM...")
+            a_step_msg = HumanMessage(content=[
+                {
+                    "type": "text",
+                    "text": f"""
+Task: {self.task}
+Current Step: {self.state.n_steps}
+
+Please analyze the current interface and decide the next action. Please directly return the response in JSON format without any other text or code block markers.
+"""
+                },
+                {
+                    "type": "image_url",
+                    # base64 screenshot extracted from the observation above
+                    "image_url": f"data:image/jpeg;base64,{base64_img}"
+                }
+            ])
+
+            messages = [SystemMessage(content=SYSTEM_PROMPT)]
+            if last_step_msg:
+                messages.append(last_step_msg)
+            messages.append(a_step_msg)
+
+            logger.info(f"[agent] VLM Input last message: {messages[-1]}")
+            llm_result = None
+            try:
+                llm_result = self._do_policy(messages)
+
+                if self.state.stopped or self.state.paused:
+                    logger.info('Android agent paused after getting state')
+                    return [ActionModel(tool_name='android', action_name="stop")]
+
+                tool_action = llm_result.actions
+
+                step_metadata = PolicyMetadata(
+                    start_time=step_start_time,
+                    end_time=time.time(),
+                    number=self.state.n_steps,
+                    input_tokens=1
+                )
+
+                history_item = AgentHistory(
+                    result=[ActionResult(success=True)],
+                    metadata=step_metadata,
+                    content=xml_content,
+                    base64_img=base64_img
+                )
+                self.history.history.append(history_item)
+
+                if self.settings.save_history and self.settings.history_path:
+                    self.history.save_to_file(self.settings.history_path)
+
+                logger.info(f'📍 Step {self.state.n_steps} starts to execute')
+
+                self.state.n_steps += 1
+                self.state.consecutive_failures = 0
+                return tool_action
+
+            except Exception as e:
+                logger.warning(traceback.format_exc())
+                raise RuntimeError("Android agent encountered exception while making the policy.", e)
+            finally:
+                if llm_result:
+                    self.trajectory.add_step(observation, info, llm_result)
+                    metadata = PolicyMetadata(
+                        number=self.state.n_steps,
+                        start_time=step_start_time,
+                        end_time=time.time(),
+                        input_tokens=1
+                    )
+                    self._make_history_item(llm_result, observation, metadata)
+                else:
+                    logger.warning("no result to record!")
+
+        except json.JSONDecodeError as e:
+            logger.error("[agent] ❌ JSON parsing error")
+            raise
+        except Exception as e:
+            logger.error(f"[agent] ❌ Action execution error: {str(e)}")
+            raise
+
+    def _do_policy(self, input_messages: list[BaseMessage]) -> AgentResult:
+        response = self.llm.invoke(input_messages)
+        content = response.content
+
+        if content.startswith("```json"):
+            content = content[7:]
+        if content.startswith("```"):
+            content = content[3:]
+        if content.endswith("```"):
+            content = content[:-3]
+        content = content.strip()
+
+        action_data = json.loads(content)
+        brain_state = AgentBrain(**action_data["current_state"])
+
+        logger.info(f"[agent] ⚠ Eval: {brain_state.evaluation_previous_goal}")
+        logger.info(f"[agent] 🧠 Memory: {brain_state.memory}")
+        logger.info(f"[agent] 🎯 Next goal: {brain_state.next_goal}")
+
+        actions = action_data.get('action')
+        result = []
+        if not actions:
+            actions = action_data.get("actions")
+
+        # print actions
+        logger.info(f"[agent] VLM Output actions: {actions}")
+        for action in actions:
+            action_type = action.get('type')
+            if not action_type:
+                logger.warning(f"Action missing type: {action}")
+                continue
+
+            params = {}
+            if 'type' == action_type:
+                action_type = 'input_text'
+            if 'params' in action:
+                params = action['params']
+            if 'index' in action:
+                params['index'] = action['index']
+            if 'type' in action:
+                params['type'] = action['type']
+            if 'text' in action:
+                params['text'] = action['text']
+
+            action_model = ActionModel(
+                tool_name='android',
+                action_name=action_type,
+                params=params
+            )
+            result.append(action_model)
+
+        return AgentResult(current_state=brain_state, actions=result)
+
+    def _make_history_item(self,
+                           model_output: AgentResult | None,
+                           state: Observation,
+                           metadata: Optional[PolicyMetadata] = None) -> None:
+        if isinstance(state, dict):
+            state = Observation(**state)
+
+        history_item = AgentHistory(
+            model_output=model_output,
+            result=state.action_result,
+            metadata=metadata,
+            content=state.dom_tree,
+            base64_img=state.image
+        )
+        self.state.history.history.append(history_item)
+
+    def pause(self) -> None:
+        """Pause the agent"""
+        logger.info('🔄 Pausing Agent')
+        self.state.paused = True
+
+    def resume(self) -> None:
+        """Resume the agent"""
+        logger.info('▶️ Agent resuming')
+        self.state.paused = False
+
+    def stop(self) -> None:
+        """Stop the agent"""
+        logger.info('⏹️ Agent stopping')
+        self.state.stopped = True
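For reference, a standalone sketch (plain Python, no aworld imports) of the response handling that `_do_policy` performs: stripping optional code fences, parsing the JSON, and normalizing each action entry into a tool-call name plus params. The sample reply below is illustrative only, not output from a real model:

import json

# Example of the kind of reply the VLM is asked to produce (illustrative only).
raw_reply = """```json
{
  "current_state": {
    "evaluation_previous_goal": "Home screen is visible",
    "memory": "Looking for the rednote app icon",
    "next_goal": "Tap the rednote icon"
  },
  "action": [
    {"type": "tap", "index": 3}
  ]
}
```"""

# Strip optional markdown code fences, mirroring _do_policy.
content = raw_reply
if content.startswith("```json"):
    content = content[7:]
if content.startswith("```"):
    content = content[3:]
if content.endswith("```"):
    content = content[:-3]
content = content.strip()

data = json.loads(content)
# "action" is preferred; "actions" is accepted as a fallback key.
actions = data.get("action") or data.get("actions") or []

for action in actions:
    action_type = action.get("type")
    params = dict(action.get("params", {}))
    for key in ("index", "type", "text"):
        if key in action:
            params[key] = action[key]
    print(action_type, params)  # e.g. tap {'index': 3, 'type': 'tap'}

In the agent itself, each (action_type, params) pair becomes an ActionModel with tool_name='android'.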
examples/android/prompts.py
ADDED
@@ -0,0 +1,58 @@
+SYSTEM_PROMPT = """
+You are an Android device automation assistant. Your task is to help users perform various operations on Android devices.
+You can perform the following actions:
+1. Tap Element (tap) - Requires parameter: index (element number)
+2. Input Text (input_text) - Requires parameter: text (text content to input)
+3. Long Press Element (long_press) - Requires parameter: index (element number)
+4. Swipe Element (swipe) - Requires parameters: index (element number), params.direction (direction: "up", "down", "left", "right"), params.dist (distance: "short", "medium", "long", optional, default is "medium")
+5. Task Completion (done) - Requires parameter: success (whether the task was successfully completed, values are true/false)
+
+Each interactive element has a number. You need to perform operations based on the element numbers displayed on the interface. Element numbers start from 1; 0 is not a valid element number. The current interface's XML and screenshot will be your input. Please carefully analyze the interface elements and choose the correct operation.
+
+Important Note: Please directly return the response in JSON format without any other text, explanations, or code block markers. The response must be a valid JSON object, formatted as follows:
+
+{
+    "current_state": {
+        "evaluation_previous_goal": "Analyze the result of the previous step",
+        "memory": "Remember important context information",
+        "next_goal": "The specific goal to execute next"
+    },
+    "action": [
+        {
+            "type": "tap",
+            "index": "Element number"
+        },
+        {
+            "type": "input_text",
+            "text": "Text content to input"
+        },
+        {
+            "type": "long_press",
+            "index": "Element number"
+        },
+        {
+            "type": "swipe",
+            "index": "Element number",
+            "params": {
+                "direction": "Swipe direction (up/down/left/right)",
+                "dist": "Swipe distance (short/medium/long, optional)"
+            }
+        },
+        {
+            "type": "done",
+            "success": "Whether the task was successfully completed (true/false)"
+        }
+    ]
+}
+
+Note:
+The index must be a valid integer starting from 1
+Do not add any other text or markers before or after the JSON
+Ensure the JSON format is entirely correct
+Each action type must include all necessary required parameters
+"""
+
+LAST_STEP_PROMPT = """Now comes your last step. Use only the "done" action now. No other actions - so here your action sequence must have length 1.
+If the task is not yet fully finished as requested by the user, set success in "done" to false! E.g. if not all steps are fully completed.
+If the task is fully finished, set success in "done" to true.
+Include everything you found out for the ultimate task in the done text."""
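To make the response contract above concrete, here is a small, hypothetical checker (not part of this upload) that restates the required parameters from SYSTEM_PROMPT and flags malformed responses:

# Required parameters per action type, restated from the prompt text above.
REQUIRED_PARAMS = {
    "tap": {"index"},
    "input_text": {"text"},
    "long_press": {"index"},
    "swipe": {"index"},
    "done": {"success"},
}

def check_response(data: dict) -> list[str]:
    """Return a list of problems found in a parsed model response (empty if it looks valid)."""
    problems = []
    state = data.get("current_state", {})
    for key in ("evaluation_previous_goal", "memory", "next_goal"):
        if key not in state:
            problems.append(f"current_state missing '{key}'")
    for i, action in enumerate(data.get("action", [])):
        action_type = action.get("type")
        if action_type not in REQUIRED_PARAMS:
            problems.append(f"action {i}: unknown type {action_type!r}")
            continue
        missing = REQUIRED_PARAMS[action_type] - set(action)
        if missing:
            problems.append(f"action {i} ({action_type}): missing {sorted(missing)}")
    return problems

# Example: a well-formed single-action response passes the check.
sample = {
    "current_state": {
        "evaluation_previous_goal": "App list visible",
        "memory": "Searching for the target app",
        "next_goal": "Tap element 2",
    },
    "action": [{"type": "tap", "index": 2}],
}
assert check_response(sample) == []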
examples/android/requirements.txt
ADDED
@@ -0,0 +1,6 @@
+langchain~=0.3.20
+langchain-openai~=0.3.8
+langchain-ollama~=0.2.3
+langchain-anthropic~=0.3.9
+langchain-mistralai~=0.2.7
+langchain-google-genai~=2.1.0
examples/android/run.py
ADDED
@@ -0,0 +1,34 @@
+# coding: utf-8
+# Copyright (c) 2025 inclusionAI.
+
+from aworld.config import AgentConfig
+from examples.android.agent import AndroidAgent
+from examples.tools.common import Agents, Tools
+from aworld.core.task import Task
+from aworld.runner import Runners
+from examples.tools.conf import AndroidToolConfig
+
+
+def main():
+    android_tool_config = AndroidToolConfig(avd_name='8ABX0PHWU',
+                                            headless=False,
+                                            max_retry=2)
+
+    agent_config: AgentConfig = AgentConfig(
+        name=Agents.ANDROID.value,
+        llm_provider="openai",
+        llm_model_name="gpt-4o",
+        llm_temperature=1,
+    )
+    agent = AndroidAgent(name=Agents.ANDROID.value, conf=agent_config)
+
+    task = Task(
+        input="""open rednote""",
+        agent=agent,
+        tools_conf={Tools.ANDROID.value: android_tool_config}
+    )
+    Runners.sync_run_task(task)
+
+
+if __name__ == '__main__':
+    main()
examples/android/utils.py
ADDED
@@ -0,0 +1,277 @@
+# coding: utf-8
+
+import json
+import traceback
+import uuid
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Optional, Dict, List
+
+from langchain_core.load import dumpd, load
+from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, SystemMessage, HumanMessage
+from openai import RateLimitError
+from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
+
+from aworld.core.agent.base import AgentResult
+from aworld.core.common import ActionResult, Observation
+
+
+class MessageMetadata(BaseModel):
+    """Metadata for a message"""
+
+    tokens: int = 0
+
+
+class ManagedMessage(BaseModel):
+    """A message with its metadata"""
+
+    message: BaseMessage
+    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    # https://github.com/pydantic/pydantic/discussions/7558
+    @model_serializer(mode='wrap')
+    def to_json(self, original_dump):
+        """
+        Returns the JSON representation of the model.
+
+        It uses langchain's `dumps` function to serialize the `message`
+        property before encoding the overall dict with json.dumps.
+        """
+        data = original_dump(self)
+
+        # NOTE: We override the message field to use langchain JSON serialization.
+        data['message'] = dumpd(self.message)
+
+        return data
+
+    @model_validator(mode='before')
+    @classmethod
+    def validate(
+        cls,
+        value: Any,
+        *,
+        strict: bool | None = None,
+        from_attributes: bool | None = None,
+        context: Any | None = None,
+    ) -> Any:
+        """
+        Custom validator that uses langchain's `loads` function
+        to parse the message if it is provided as a JSON string.
+        """
+        if isinstance(value, dict) and 'message' in value:
+            # NOTE: We use langchain's load to convert the JSON string back into a BaseMessage object.
+            value['message'] = load(value['message'])
+        return value
+
+
+class MessageHistory(BaseModel):
+    """History of messages with metadata"""
+
+    messages: list[ManagedMessage] = Field(default_factory=list)
+    current_tokens: int = 0
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
+        """Add message with metadata to history"""
+        if position is None:
+            self.messages.append(ManagedMessage(message=message, metadata=metadata))
+        else:
+            self.messages.insert(position, ManagedMessage(message=message, metadata=metadata))
+        self.current_tokens += metadata.tokens
+
+    def add_model_output(self, output) -> None:
+        """Add model output as AI message"""
+        tool_calls = [
+            {
+                'name': 'AgentOutput',
+                'args': output.model_dump(mode='json', exclude_unset=True),
+                'id': '1',
+                'type': 'tool_call',
+            }
+        ]
+
+        msg = AIMessage(
+            content='',
+            tool_calls=tool_calls,
+        )
+        self.add_message(msg, MessageMetadata(tokens=100))  # Estimate tokens for tool calls
+
+        # Empty tool response
+        tool_message = ToolMessage(content='', tool_call_id='1')
+        self.add_message(tool_message, MessageMetadata(tokens=10))  # Estimate tokens for empty response
+
+    def get_messages(self) -> list[BaseMessage]:
+        """Get all messages"""
+        return [m.message for m in self.messages]
+
+    def get_total_tokens(self) -> int:
+        """Get total tokens in history"""
+        return self.current_tokens
+
+    def remove_oldest_message(self) -> None:
+        """Remove oldest non-system message"""
+        for i, msg in enumerate(self.messages):
+            if not isinstance(msg.message, SystemMessage):
+                self.current_tokens -= msg.metadata.tokens
+                self.messages.pop(i)
+                break
+
+    def remove_last_state_message(self) -> None:
+        """Remove last state message from history"""
+        if len(self.messages) > 2 and isinstance(self.messages[-1].message, HumanMessage):
+            self.current_tokens -= self.messages[-1].metadata.tokens
+            self.messages.pop()
+
+
+class MessageManagerState(BaseModel):
+    """Holds the state for MessageManager"""
+
+    history: MessageHistory = Field(default_factory=MessageHistory)
+    tool_id: int = 1
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class AgentSettings(BaseModel):
+    """Options for the agent"""
+    max_failures: int = 3
+    retry_delay: int = 10
+    save_history: bool = True
+    history_path: Optional[str] = None
+    max_actions_per_step: int = 10
+    validate_output: bool = False
+    message_context: Optional[str] = None
+
+
+class PolicyMetadata(BaseModel):
+    """Metadata for a single step including timing information"""
+    start_time: float
+    end_time: float
+    number: int
+    input_tokens: int
+
+    @property
+    def duration_seconds(self) -> float:
+        """Calculate step duration in seconds"""
+        return self.end_time - self.start_time
+
+
+class AgentBrain(BaseModel):
+    """Current state of the agent"""
+    evaluation_previous_goal: str
+    memory: str
+    next_goal: str
+
+
+class AgentHistory(BaseModel):
+    """History item for agent actions"""
+    model_output: Optional[BaseModel] = None
+    result: List[ActionResult]
+    metadata: Optional[PolicyMetadata] = None
+    content: Optional[str] = None
+    base64_img: Optional[str] = None
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def model_dump(self, **kwargs) -> Dict[str, Any]:
+        """Custom serialization handling"""
+        return {
+            'model_output': self.model_output.model_dump() if self.model_output else None,
+            'result': [r.model_dump(exclude_none=True) for r in self.result],
+            'metadata': self.metadata.model_dump() if self.metadata else None,
+            'content': self.content,
+            'base64_img': self.base64_img
+        }
+
+
+class AgentHistoryList(BaseModel):
+    """List of agent history items"""
+    history: List[AgentHistory]
+
+    def total_duration_seconds(self) -> float:
+        """Get total duration of all steps in seconds"""
+        total = 0.0
+        for h in self.history:
+            if h.metadata:
+                total += h.metadata.duration_seconds
+        return total
+
+    def save_to_file(self, filepath: str | Path) -> None:
+        """Save history to JSON file with proper serialization"""
+        try:
+            Path(filepath).parent.mkdir(parents=True, exist_ok=True)
+            data = self.model_dump()
+            with open(filepath, 'w', encoding='utf-8') as f:
+                json.dump(data, f, indent=2)
+        except Exception as e:
+            raise e
+
+    def model_dump(self, **kwargs) -> Dict[str, Any]:
+        """Custom serialization that properly uses AgentHistory's model_dump"""
+        return {
+            'history': [h.model_dump(**kwargs) for h in self.history],
+        }
+
+    @classmethod
+    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
+        """Load history from JSON file"""
+        with open(filepath, 'r', encoding='utf-8') as f:
+            data = json.load(f)
+        return cls.model_validate(data)
+
+
+class AgentError:
+    """Container for agent error handling"""
+    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
+    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
+    NO_VALID_ACTION = 'No valid action found'
+
+    @staticmethod
+    def format_error(error: Exception, include_trace: bool = False) -> str:
+        """Format error message based on error type and optionally include trace"""
+        if isinstance(error, RateLimitError):
+            return AgentError.RATE_LIMIT_ERROR
+        if include_trace:
+            return f'{str(error)}\nStacktrace:\n{traceback.format_exc()}'
+        return f'{str(error)}'
+
+
+class AgentState(BaseModel):
+    """Holds all state information for an Agent"""
+
+    agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    n_steps: int = 1
+    consecutive_failures: int = 0
+    last_result: Optional[List['ActionResult']] = None
+    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
+    last_plan: Optional[str] = None
+    paused: bool = False
+    stopped: bool = False
+    message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState)
+
+
+@dataclass
+class AgentStepInfo:
+    number: int
+    max_steps: int
+
+    def is_last_step(self) -> bool:
+        """Check if this is the last step"""
+        return self.number >= self.max_steps - 1
+
+
+@dataclass
+class Trajectory:
+    """Stores the agent's history, including all observations, info, and AgentResults."""
+    history: List[tuple[Observation, Dict[str, Any], AgentResult]] = field(default_factory=list)
+
+    def add_step(self, observation: Observation, info: Dict[str, Any], agent_result: AgentResult):
+        """Add a step to the history"""
+        self.history.append((observation, info, agent_result))
+
+    def get_history(self) -> List[tuple[Observation, Dict[str, Any], AgentResult]]:
+        """Retrieve the complete history"""
+        return self.history
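A minimal usage sketch for the history helpers above, assuming the `aworld` package and this `examples` tree are importable; the output path is just an example:

import time
from examples.android.utils import (
    ActionResult, AgentHistory, AgentHistoryList, PolicyMetadata,
)

start = time.time()
step = AgentHistory(
    result=[ActionResult(success=True)],
    metadata=PolicyMetadata(start_time=start, end_time=time.time(), number=1, input_tokens=1),
    content="<hierarchy/>",  # placeholder UI XML
    base64_img=None,
)

history = AgentHistoryList(history=[step])
print(f"total duration: {history.total_duration_seconds():.3f}s")

# Round-trip through the custom JSON serialization.
history.save_to_file("/tmp/android_agent_history.json")
restored = AgentHistoryList.load_from_file("/tmp/android_agent_history.json")
assert len(restored.history) == 1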