Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
AWorld-main/aworlddistributed/aworldspace/agents/gaia_agent.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Dict, Any, List, Optional
|
6 |
+
|
7 |
+
from aworld.config import ModelConfig
|
8 |
+
from aworld.config.conf import AgentConfig, TaskConfig, ClientType
|
9 |
+
from aworld.core.task import Task
|
10 |
+
from aworld.output import Outputs, Output, StreamingOutputs
|
11 |
+
from aworld.utils.common import get_local_ip
|
12 |
+
from datasets import load_dataset, concatenate_datasets
|
13 |
+
from pydantic import BaseModel, Field
|
14 |
+
|
15 |
+
from aworldspace.base_agent import AworldBaseAgent
|
16 |
+
from aworldspace.utils.mcp_utils import load_all_mcp_config
|
17 |
+
from aworldspace.utils.utils import question_scorer
|
18 |
+
|
19 |
+
GAIA_SYSTEM_PROMPT = f"""You are an all-capable AI assistant, aimed at solving any task presented by the user. You have various tools at your disposal that you can call upon to efficiently complete complex requests. Whether it's programming, information retrieval, file processing, or web browsing, you can handle it all.
|
20 |
+
Please note that the task may be complex. Do not attempt to solve it all at once. You should break the task down and use different tools step by step to solve it. After using each tool, clearly explain the execution results and suggest the next steps.
|
21 |
+
Please utilize appropriate tools for the task, analyze the results obtained from these tools, and provide your reasoning. Always use available tools such as browser, calcutor, etc. to verify correctness rather than relying on your internal knowledge.
|
22 |
+
If you believe the problem has been solved, please output the `final answer`. The `final answer` should be given in <answer></answer> format, while your other thought process should be output in <think></think> tags.
|
23 |
+
Your `final answer` should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
24 |
+
|
25 |
+
Here are some tips to help you give better instructions:
|
26 |
+
<tips>
|
27 |
+
1. Do not use any tools outside of the provided tools list.
|
28 |
+
2. Even if the task is complex, there is always a solution. If you can’t find the answer using one method, try another approach or use different tools to find the solution.
|
29 |
+
3. When using browser `playwright_click` tool, you need to check if the element exists and is clickable before clicking it.
|
30 |
+
4. Before providing the `final answer`, carefully reflect on whether the task has been fully solved. If you have not solved the task, please provide your reasoning and suggest the next steps.
|
31 |
+
5. Due to context length limitations, always try to complete browser-based tasks with the minimal number of steps possible.
|
32 |
+
6. When providing the `final answer`, answer the user's question directly and precisely. For example, if asked "what animal is x?" and x is a monkey, simply answer "monkey" rather than "x is a monkey".
|
33 |
+
7. When you need to process excel file, prioritize using the `excel` tool instead of writing custom code with `terminal-controller` tool.
|
34 |
+
8. If you need to download a file, please use the `terminal-controller` tool to download the file and save it to the specified path.
|
35 |
+
9. The browser doesn't support direct searching on www.google.com. Use the `google-search` to get the relevant website URLs or contents instead of `ms-playwright` directly.
|
36 |
+
10. Always use only one tool at a time in each step of your execution.
|
37 |
+
11. Using `mcp__ms-playwright__browser_pdf_save` tool to save the pdf file of URLs to the specified path.
|
38 |
+
12. Using `mcp__terminal-controller__execute_command` tool to set the timeout to 300 seconds when downloading large files such as pdf.
|
39 |
+
13. Using `mcp__ms-playwright__browser_take_screenshot` tool to save the screenshot of URLs to the specified path when you need to understand the gif / jpg of the URLs.
|
40 |
+
14. When there are questions related to YouTube video comprehension, use tools in `youtube_download_server` and `video_server` to analyze the video content by the given question.
|
41 |
+
</tips>
|
42 |
+
|
43 |
+
Now, here is the task. Stay focused and complete it carefully using the appropriate tools!
|
44 |
+
"""
|
45 |
+
|
46 |
+
class Pipeline(AworldBaseAgent):
|
47 |
+
class Valves(BaseModel):
|
48 |
+
llm_provider: Optional[str] = Field(default=None, description="llm_model_name")
|
49 |
+
llm_model_name: Optional[str] = Field(default=None, description="llm_model_name")
|
50 |
+
llm_base_url: Optional[str] = Field(default=None,description="llm_base_urly")
|
51 |
+
llm_api_key: Optional[str] = Field(default=None,description="llm api key" )
|
52 |
+
system_prompt: str = Field(default=GAIA_SYSTEM_PROMPT,description="system_prompt")
|
53 |
+
history_messages: int = Field(default=100, description="rounds of history messages")
|
54 |
+
|
55 |
+
def __init__(self):
|
56 |
+
self.valves = self.Valves()
|
57 |
+
self.gaia_files = os.path.abspath(os.path.join(os.path.curdir, "aworldspace", "datasets", "gaia_dataset"))
|
58 |
+
logging.info(f"gaia_files path {self.gaia_files}")
|
59 |
+
self.full_dataset = load_dataset(
|
60 |
+
os.path.join(self.gaia_files, "GAIA.py"),
|
61 |
+
name="2023_all",
|
62 |
+
trust_remote_code=True
|
63 |
+
)
|
64 |
+
self.full_dataset = concatenate_datasets([self.full_dataset['validation'], self.full_dataset['test']])
|
65 |
+
|
66 |
+
# Create task_id to index mapping for improved lookup performance
|
67 |
+
self.task_id_to_index = {}
|
68 |
+
for i, task in enumerate(self.full_dataset):
|
69 |
+
self.task_id_to_index[task['task_id']] = i
|
70 |
+
|
71 |
+
logging.info(f"Loaded {len(self.full_dataset)} tasks, created task_id mapping")
|
72 |
+
logging.info("gaia_agent init success")
|
73 |
+
|
74 |
+
async def get_custom_input(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Any:
|
75 |
+
task = await self.get_gaia_task(user_message)
|
76 |
+
logging.info(f"🌈 -----------------------------------------------")
|
77 |
+
logging.info(f"🚀 Start to process: gaia_task_{task['task_id']}")
|
78 |
+
logging.info(f"📝 Detail: {task}")
|
79 |
+
logging.info(f"❓ Question: {task['Question']}")
|
80 |
+
logging.info(f"⭐ Level: {task['Level']}")
|
81 |
+
logging.info(f"🛠️ Tools: {task['Annotator Metadata']['Tools']}")
|
82 |
+
logging.info(f"🌈 -----------------------------------------------")
|
83 |
+
return task['Question']
|
84 |
+
|
85 |
+
async def get_agent_config(self, body):
|
86 |
+
default_llm_provider = self.valves.llm_provider if self.valves.llm_provider else os.environ.get("LLM_PROVIDER")
|
87 |
+
llm_model_name = self.valves.llm_model_name if self.valves.llm_model_name else os.environ.get("LLM_MODEL_NAME")
|
88 |
+
llm_api_key = self.valves.llm_api_key if self.valves.llm_api_key else os.environ.get("LLM_API_KEY")
|
89 |
+
llm_base_url = self.valves.llm_base_url if self.valves.llm_base_url else os.environ.get("LLM_BASE_URL")
|
90 |
+
system_prompt = self.valves.system_prompt if self.valves.system_prompt else GAIA_SYSTEM_PROMPT
|
91 |
+
|
92 |
+
task = await self.get_task_from_body(body)
|
93 |
+
if task:
|
94 |
+
logging.info(f"task llm config is: {task.llm_provider}, {task.llm_model_name}, {task.llm_api_key}, {task.llm_base_url}")
|
95 |
+
|
96 |
+
llm_config = ModelConfig(
|
97 |
+
llm_provider=task.llm_provider if task and task.llm_provider else default_llm_provider,
|
98 |
+
llm_model_name=task.llm_model_name if task and task.llm_model_name else llm_model_name,
|
99 |
+
llm_api_key=task.llm_api_key if task and task.llm_api_key else llm_api_key,
|
100 |
+
llm_base_url=task.llm_base_url if task and task.llm_base_url else llm_base_url,
|
101 |
+
max_retries=task.max_retries if task and task.max_retries else 3
|
102 |
+
)
|
103 |
+
|
104 |
+
return AgentConfig(
|
105 |
+
name=self.agent_name(),
|
106 |
+
llm_config=llm_config,
|
107 |
+
system_prompt=task.task_system_prompt if task and task.task_system_prompt else system_prompt
|
108 |
+
)
|
109 |
+
|
110 |
+
def agent_name(self) -> str:
|
111 |
+
return "GaiaAgent"
|
112 |
+
|
113 |
+
async def get_mcp_servers(self, body) -> list[str]:
|
114 |
+
task = await self.get_task_from_body(body)
|
115 |
+
if task and task.mcp_servers:
|
116 |
+
logging.info(f"mcp_servers from task: {task.mcp_servers}")
|
117 |
+
return task.mcp_servers
|
118 |
+
|
119 |
+
return [
|
120 |
+
"e2b-server",
|
121 |
+
"terminal-controller",
|
122 |
+
"excel",
|
123 |
+
"calculator",
|
124 |
+
"ms-playwright",
|
125 |
+
"audio_server",
|
126 |
+
"image_server",
|
127 |
+
"video_server",
|
128 |
+
"search_server",
|
129 |
+
"download_server",
|
130 |
+
"document_server",
|
131 |
+
"youtube_server",
|
132 |
+
"reasoning_server",
|
133 |
+
]
|
134 |
+
|
135 |
+
async def get_gaia_task(self, task_id: str) -> dict:
|
136 |
+
"""
|
137 |
+
Get GAIA task by task_id
|
138 |
+
Args:
|
139 |
+
task_id: Unique identifier of the task
|
140 |
+
Returns:
|
141 |
+
Corresponding task dictionary
|
142 |
+
"""
|
143 |
+
|
144 |
+
# Search by task_id
|
145 |
+
if task_id in self.task_id_to_index:
|
146 |
+
index = self.task_id_to_index[task_id]
|
147 |
+
gaia_task = self.full_dataset[index]
|
148 |
+
else:
|
149 |
+
raise ValueError(f"Task with task_id '{task_id}' not found in dataset")
|
150 |
+
|
151 |
+
return self.add_file_path(gaia_task)
|
152 |
+
|
153 |
+
def get_all_task_ids(self) -> List[str]:
|
154 |
+
"""
|
155 |
+
Get list of all available task_ids
|
156 |
+
Returns:
|
157 |
+
List of all task_ids
|
158 |
+
"""
|
159 |
+
return list(self.task_id_to_index.keys())
|
160 |
+
|
161 |
+
def get_task_count(self) -> int:
|
162 |
+
"""
|
163 |
+
Get total number of tasks
|
164 |
+
Returns:
|
165 |
+
Total task count
|
166 |
+
"""
|
167 |
+
return len(self.full_dataset)
|
168 |
+
|
169 |
+
def get_task_index_by_id(self, task_id: str) -> int:
|
170 |
+
"""
|
171 |
+
Get task index in dataset by task_id
|
172 |
+
Args:
|
173 |
+
task_id: Unique identifier of the task
|
174 |
+
Returns:
|
175 |
+
Index of the task in the dataset
|
176 |
+
"""
|
177 |
+
if task_id in self.task_id_to_index:
|
178 |
+
return self.task_id_to_index[task_id]
|
179 |
+
else:
|
180 |
+
raise ValueError(f"Task with task_id '{task_id}' not found in dataset")
|
181 |
+
|
182 |
+
async def custom_output_before_task(self, outputs: Outputs, chat_id: str, task: Task) -> None:
|
183 |
+
task_config:TaskConfig = task.conf
|
184 |
+
gaia_task = await self.get_gaia_task(task_config.ext['origin_message'])
|
185 |
+
|
186 |
+
result = f"\n\n`{get_local_ip()}` execute `GAIA TASK#{task_config.ext['origin_message']}`:\n\n---\n\n"
|
187 |
+
result += f"**Question**: {gaia_task['Question']}\n"
|
188 |
+
result += f"**Answer**: {gaia_task['Final answer']}\n"
|
189 |
+
result += f"**Level**: {gaia_task['Level']}\n"
|
190 |
+
result += f"**Tools**: \n {gaia_task['Annotator Metadata']['Tools']}\n"
|
191 |
+
result += f"\n\n-----\n\n"
|
192 |
+
await outputs.add_output(Output(data = result))
|
193 |
+
|
194 |
+
async def custom_output_after_task(self, outputs: Outputs, chat_id: str, task: Task):
|
195 |
+
"""
|
196 |
+
check gaia task output
|
197 |
+
Args:
|
198 |
+
outputs:
|
199 |
+
chat_id:
|
200 |
+
task:
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
|
204 |
+
"""
|
205 |
+
task_config: TaskConfig = task.conf
|
206 |
+
gaia_task_id = task_config['ext']['origin_message']
|
207 |
+
gaia_task = await self.get_gaia_task(gaia_task_id)
|
208 |
+
agent_result = ""
|
209 |
+
if isinstance(outputs, StreamingOutputs):
|
210 |
+
agent_result = await outputs._visited_outputs[-2].get_finished_response() # read llm result
|
211 |
+
match = re.search(r"<answer>(.*?)</answer>", agent_result)
|
212 |
+
answer = agent_result
|
213 |
+
if match:
|
214 |
+
answer = match.group(1)
|
215 |
+
|
216 |
+
logging.info(f"🤖 Agent answer: {answer}")
|
217 |
+
logging.info(f"👨🏫 Correct answer: {gaia_task['Final answer']}")
|
218 |
+
is_correct = question_scorer(answer, gaia_task["Final answer"])
|
219 |
+
|
220 |
+
if is_correct:
|
221 |
+
logging.info(f"📝Question {gaia_task_id} Correct! 🎉")
|
222 |
+
result = f"\n\n📝 **Question: {gaia_task_id} -> Agent Answer:[{answer}] is `Correct`**"
|
223 |
+
else:
|
224 |
+
logging.info(f"📝Question {gaia_task_id} Incorrect! ❌")
|
225 |
+
result = f"\n\n📝 **Question: {gaia_task_id} -> Agent Answer:`{answer}` != Correct answer: `{gaia_task['Final answer']}` is `Incorrect` ❌**"
|
226 |
+
|
227 |
+
metadata = await outputs.get_metadata()
|
228 |
+
if not metadata:
|
229 |
+
await outputs.set_metadata({})
|
230 |
+
metadata = await outputs.get_metadata()
|
231 |
+
metadata['gaia_correct'] = is_correct
|
232 |
+
metadata['gaia_result'] = result
|
233 |
+
metadata['agent_answer'] = answer
|
234 |
+
metadata['correct_answer'] = gaia_task['Final answer']
|
235 |
+
return result
|
236 |
+
|
237 |
+
|
238 |
+
|
239 |
+
def add_file_path(self, task: Dict[str, Any]
|
240 |
+
):
|
241 |
+
split = "validation" if task["Annotator Metadata"]["Steps"] != "" else "test"
|
242 |
+
|
243 |
+
if task["file_name"]:
|
244 |
+
file_path = Path(f"{self.gaia_files}/2023/{split}/" + task["file_name"])
|
245 |
+
if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
|
246 |
+
task["Question"] += f" Here are the necessary document files: {file_path}"
|
247 |
+
|
248 |
+
elif file_path.suffix in [".jpg", ".jpeg", ".png"]:
|
249 |
+
task["Question"] += f" Here are the necessary image files: {file_path}"
|
250 |
+
|
251 |
+
elif file_path.suffix in [".xlsx", "xls", ".csv"]:
|
252 |
+
task[
|
253 |
+
"Question"
|
254 |
+
] += f" Here are the necessary table files: {file_path}, for processing excel file, you can use the excel tool or write python code to process the file step-by-step and get the information."
|
255 |
+
|
256 |
+
elif file_path.suffix in [".py"]:
|
257 |
+
task["Question"] += f" Here are the necessary python files: {file_path}"
|
258 |
+
|
259 |
+
else:
|
260 |
+
task["Question"] += f" Here are the necessary files: {file_path}"
|
261 |
+
|
262 |
+
return task
|
263 |
+
async def load_mcp_config(self) -> dict:
|
264 |
+
return load_all_mcp_config()
|
AWorld-main/aworlddistributed/aworldspace/agents/playwright_agent.py
ADDED
@@ -0,0 +1,792 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import traceback
|
4 |
+
from typing import Dict, Any, List, Union
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
from aworldspace.base_agent import AworldBaseAgent
|
8 |
+
from pydantic import BaseModel, Field
|
9 |
+
|
10 |
+
import aworld.trace as trace
|
11 |
+
from aworld.config.conf import AgentConfig, ConfigDict
|
12 |
+
from aworld.config.conf import TaskConfig
|
13 |
+
from aworld.agents.llm_agent import Agent
|
14 |
+
from aworld.core.common import Observation, ActionModel
|
15 |
+
from aworld.core.memory import MemoryItem
|
16 |
+
from aworld.core.task import Task
|
17 |
+
from aworld.logs.util import logger
|
18 |
+
from aworld.models.llm import acall_llm_model
|
19 |
+
from aworld.models.model_response import ToolCall, Function
|
20 |
+
from aworld.output import Output, StreamingOutputs
|
21 |
+
from aworld.output import Outputs
|
22 |
+
from aworld.output.base import MessageOutput
|
23 |
+
from aworld.utils.common import sync_exec
|
24 |
+
|
25 |
+
BROWSER_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
26 |
+
## Output Format
|
27 |
+
```
|
28 |
+
Thought: ...
|
29 |
+
Action: ...
|
30 |
+
```
|
31 |
+
## Action Space
|
32 |
+
navigate(website='xxx') #Open the target website, usually the first action to open browser.
|
33 |
+
click(start_box='[x1, y1, x2, y2]')
|
34 |
+
left_double(start_box='[x1, y1, x2, y2]')
|
35 |
+
right_single(start_box='[x1, y1, x2, y2]')
|
36 |
+
drag(start_box='[x1, y1, x2, y2]', end_box='[x3, y3, x4, y4]')
|
37 |
+
hotkey(key='')
|
38 |
+
type(content='') #If you want to submit your input, use "\n" at the end of `content`.
|
39 |
+
scroll(direction='down or up or right or left')
|
40 |
+
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
41 |
+
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
42 |
+
## Note
|
43 |
+
- only one action per step.
|
44 |
+
- Use Chinese in `Thought` part.
|
45 |
+
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
46 |
+
## User Instruction
|
47 |
+
"""
|
48 |
+
|
49 |
+
import json
|
50 |
+
import re
|
51 |
+
|
52 |
+
MAX_IMAGE = 50
|
53 |
+
|
54 |
+
|
55 |
+
def parse_action_output(output_text):
|
56 |
+
# 提取Thought部分
|
57 |
+
logger.info(f"{output_text=}")
|
58 |
+
thought_match = re.search(r'Thought:(.*?)\nAction:', output_text, re.DOTALL)
|
59 |
+
thought = thought_match.group(1).strip() if thought_match else ""
|
60 |
+
|
61 |
+
# 提取Action部分
|
62 |
+
action_match = re.search(r'Action:(.*?)(?:\n|$)', output_text, re.DOTALL)
|
63 |
+
action_text = action_match.group(1).strip() if action_match else ""
|
64 |
+
|
65 |
+
# 初始化结果字典
|
66 |
+
result = {
|
67 |
+
"thought": thought,
|
68 |
+
"action": "",
|
69 |
+
"key": None,
|
70 |
+
"content": None,
|
71 |
+
"start_box": None,
|
72 |
+
"end_box": None,
|
73 |
+
"direction": None,
|
74 |
+
"website": None,
|
75 |
+
}
|
76 |
+
|
77 |
+
if not action_text:
|
78 |
+
return json.dumps(result, ensure_ascii=False)
|
79 |
+
|
80 |
+
# tmp 兼容ui-tars1.5-7b
|
81 |
+
action_text = action_text.replace("'(","'[").replace(")'","]'")
|
82 |
+
|
83 |
+
# 解析action类型
|
84 |
+
action_parts = action_text.split('(')
|
85 |
+
action_type = action_parts[0]
|
86 |
+
result["action"] = action_type
|
87 |
+
|
88 |
+
# 解析参数
|
89 |
+
if len(action_parts) > 1:
|
90 |
+
params_text = action_parts[1].rstrip(')')
|
91 |
+
params = {}
|
92 |
+
|
93 |
+
# gpt-4o兼容
|
94 |
+
if 'start_box' in params_text:
|
95 |
+
params_text = params_text.replace(", ", " ").replace(",", " ")
|
96 |
+
if 'end_box' in params_text:
|
97 |
+
params_text = params_text.replace(" end_box", ", end_box")
|
98 |
+
|
99 |
+
# 处理键值对参数
|
100 |
+
for param in params_text.split(','):
|
101 |
+
param = param.strip()
|
102 |
+
|
103 |
+
if '=' in param:
|
104 |
+
key, value = param.split('=', 1)
|
105 |
+
key = key.strip()
|
106 |
+
value = value.strip().strip('\'"')
|
107 |
+
|
108 |
+
# 处理bbox格式
|
109 |
+
if 'box' in key:
|
110 |
+
print(value)
|
111 |
+
# 提取坐标数字
|
112 |
+
numbers = re.findall(r'\d+', value)
|
113 |
+
print(numbers)
|
114 |
+
if numbers:
|
115 |
+
coords = [int(num) for num in numbers]
|
116 |
+
if len(coords) == 4:
|
117 |
+
if key == 'start_box':
|
118 |
+
result["start_box"] = coords
|
119 |
+
elif key == 'end_box':
|
120 |
+
result["end_box"] = coords
|
121 |
+
if len(coords) == 2:
|
122 |
+
if key == 'start_box':
|
123 |
+
result["start_box"] = [coords[0], coords[1], coords[0], coords[1]]
|
124 |
+
elif key == 'end_box':
|
125 |
+
result["end_box"] = [coords[0], coords[1], coords[0], coords[1]]
|
126 |
+
elif key == 'key':
|
127 |
+
result["key"] = value.replace("pagedown", "PageDown").replace("pageup", "PageUp").replace("enter","Enter")
|
128 |
+
elif key == 'content':
|
129 |
+
# 处理转义字符
|
130 |
+
value = value.replace('\\n', '\n').replace('\\"', '"').replace("\\'", "'")
|
131 |
+
result["content"] = value
|
132 |
+
elif key == 'website':
|
133 |
+
result["website"] = value
|
134 |
+
elif key == 'direction':
|
135 |
+
result["direction"] = value
|
136 |
+
|
137 |
+
return result, thought, action_text
|
138 |
+
|
139 |
+
|
140 |
+
def parse_tool_call(line):
|
141 |
+
# 提取 Action和param
|
142 |
+
result, thought, action_text = parse_action_output(line)
|
143 |
+
action = result['action']
|
144 |
+
|
145 |
+
# 映射到实际函数名和参数
|
146 |
+
if action == 'navigate':
|
147 |
+
func_name = 'mcp__ms-playwright__browser_navigate'
|
148 |
+
content = {'url': result['website']}
|
149 |
+
|
150 |
+
elif action == 'click':
|
151 |
+
func_name = 'mcp__ms-playwright__browser_screen_click'
|
152 |
+
|
153 |
+
x = int((result["start_box"][0] + result["start_box"][2]) / 2)
|
154 |
+
y = int((result["start_box"][1] + result["start_box"][3]) / 2)
|
155 |
+
content = {'element': '', 'x': x, 'y': y}
|
156 |
+
|
157 |
+
elif action == 'right_single':
|
158 |
+
func_name = 'mcp__ms-playwright__browser_screen_click'
|
159 |
+
|
160 |
+
x = int((result["start_box"][0] + result["start_box"][2]) / 2)
|
161 |
+
y = int((result["start_box"][1] + result["start_box"][3]) / 2)
|
162 |
+
content = {'element': 'right click target', 'x': x, 'y': y, 'button': 'right'}
|
163 |
+
|
164 |
+
elif action == 'drag':
|
165 |
+
func_name = 'mcp__ms-playwright__browser_screen_drag'
|
166 |
+
|
167 |
+
x1 = int((result["start_box"][0] + result["start_box"][2]) / 2)
|
168 |
+
y1 = int((result["start_box"][1] + result["start_box"][3]) / 2)
|
169 |
+
x2 = int((result["end_box"][0] + result["end_box"][2]) / 2)
|
170 |
+
y2 = int((result["end_box"][1] + result["end_box"][3]) / 2)
|
171 |
+
|
172 |
+
content = {
|
173 |
+
'element': f'drag from [{x1},{y1}] to [{x2},{y2}]',
|
174 |
+
'startX': x1,
|
175 |
+
'startY': y1,
|
176 |
+
'endX': x2,
|
177 |
+
'endY': y2
|
178 |
+
}
|
179 |
+
elif action == 'hotkey':
|
180 |
+
func_name = 'mcp__ms-playwright__browser_press_key'
|
181 |
+
content = {'key': result["key"]}
|
182 |
+
elif action == 'type':
|
183 |
+
func_name = 'mcp__ms-playwright__browser_screen_type'
|
184 |
+
content = {'text': result['content']}
|
185 |
+
elif action == 'scroll':
|
186 |
+
# 暂时使用presskey代替scroll
|
187 |
+
func_name = 'mcp__ms-playwright__browser_press_key'
|
188 |
+
direction = result['direction']
|
189 |
+
key_map = {
|
190 |
+
'up': 'PageUp',
|
191 |
+
'down': 'PageDown',
|
192 |
+
'left': 'ArrowLeft',
|
193 |
+
'right': 'ArrowRight'
|
194 |
+
}
|
195 |
+
key = key_map.get(direction, 'ArrowDown')
|
196 |
+
content = {'key': key}
|
197 |
+
elif action == 'wait':
|
198 |
+
func_name = 'mcp__ms-playwright__browser_wait_for'
|
199 |
+
content = {'time': 5}
|
200 |
+
elif action == 'finished':
|
201 |
+
func_name = "finished"
|
202 |
+
content = result['content']
|
203 |
+
else:
|
204 |
+
return ""
|
205 |
+
|
206 |
+
return Function(name=func_name, arguments=json.dumps(content)), thought, action_text, result
|
207 |
+
|
208 |
+
# eval code start
|
209 |
+
|
210 |
+
|
211 |
+
def identify_key_points(task):
|
212 |
+
system_msg = """You are an expert tasked with analyzing a given task to identify the key points explicitly stated in the task description.
|
213 |
+
|
214 |
+
**Objective**: Carefully analyze the task description and extract the critical elements explicitly mentioned in the task for achieving its goal.
|
215 |
+
|
216 |
+
**Instructions**:
|
217 |
+
1. Read the task description carefully.
|
218 |
+
2. Identify and extract **key points** directly stated in the task description.
|
219 |
+
- A **key point** is a critical element, condition, or step explicitly mentioned in the task description.
|
220 |
+
- Do not infer or add any unstated elements.
|
221 |
+
- Words such as "best," "highest," "cheapest," "latest," "most recent," "lowest," "closest," "highest-rated," "largest," and "newest" must go through the sort function(e.g., the key point should be "Filter by highest").
|
222 |
+
|
223 |
+
**Respond with**:
|
224 |
+
- **Key Points**: A numbered list of the explicit key points for completing this task, one per line, without explanations or additional details."""
|
225 |
+
prompt = """Task: {task}"""
|
226 |
+
text = prompt.format(task=task)
|
227 |
+
messages = [
|
228 |
+
{"role": "system", "content": system_msg},
|
229 |
+
{
|
230 |
+
"role": "user",
|
231 |
+
"content": [
|
232 |
+
{"type": "text", "text": text}
|
233 |
+
],
|
234 |
+
}
|
235 |
+
]
|
236 |
+
return messages
|
237 |
+
|
238 |
+
|
239 |
+
def judge_image(task, image_path, key_points):
|
240 |
+
system_msg = """You are an expert evaluator tasked with determining whether an image contains information about the necessary steps to complete a task.
|
241 |
+
|
242 |
+
**Objective**: Analyze the provided image and decide if it shows essential steps or evidence required for completing the task. Use your reasoning to explain your decision before assigning a score.
|
243 |
+
|
244 |
+
**Instructions**:
|
245 |
+
1. Provide a detailed description of the image, including its contents, visible elements, text (if any), and any notable features.
|
246 |
+
|
247 |
+
2. Carefully examine the image and evaluate whether it contains necessary steps or evidence crucial to task completion:
|
248 |
+
- Identify key points that could be relevant to task completion, such as actions, progress indicators, tool usage, applied filters, or step-by-step instructions.
|
249 |
+
- Does the image show actions, progress indicators, or critical information directly related to completing the task?
|
250 |
+
- Is this information indispensable for understanding or ensuring task success?
|
251 |
+
- If the image contains partial but relevant information, consider its usefulness rather than dismissing it outright.
|
252 |
+
|
253 |
+
3. Provide your response in the following format:
|
254 |
+
- **Reasoning**: Explain your thought process and observations. Mention specific elements in the image that indicate necessary steps, evidence, or lack thereof.
|
255 |
+
- **Score**: Assign a score based on the reasoning, using the following scale:
|
256 |
+
- **1**: The image does not contain any necessary steps or relevant information.
|
257 |
+
- **2**: The image contains minimal or ambiguous information, unlikely to be essential.
|
258 |
+
- **3**: The image includes some relevant steps or hints but lacks clarity or completeness.
|
259 |
+
- **4**: The image contains important steps or evidence that are highly relevant but not fully comprehensive.
|
260 |
+
- **5**: The image clearly displays necessary steps or evidence crucial for completing the task.
|
261 |
+
|
262 |
+
Respond with:
|
263 |
+
1. **Reasoning**: [Your explanation]
|
264 |
+
2. **Score**: [1-5]"""
|
265 |
+
|
266 |
+
# jpg_base64_str = encode_image(Image.open(image_path))
|
267 |
+
|
268 |
+
prompt = """**Task**: {task}
|
269 |
+
|
270 |
+
**Key Points for Task Completion**: {key_points}
|
271 |
+
|
272 |
+
The snapshot of the web page is shown in the image."""
|
273 |
+
text = prompt.format(task=task, key_points=key_points)
|
274 |
+
|
275 |
+
messages = [
|
276 |
+
{"role": "system", "content": system_msg},
|
277 |
+
{
|
278 |
+
"role": "user",
|
279 |
+
"content": [
|
280 |
+
{"type": "text", "text": text},
|
281 |
+
{
|
282 |
+
"type": "image_url",
|
283 |
+
"image_url": {"url": image_path, "detail": "high"},
|
284 |
+
},
|
285 |
+
],
|
286 |
+
}
|
287 |
+
]
|
288 |
+
|
289 |
+
return messages
|
290 |
+
|
291 |
+
|
292 |
+
def WebJudge_Online_Mind2Web_eval(task, last_actions, images_path, image_responses, key_points, score_threshold):
|
293 |
+
system_msg = """You are an expert in evaluating the performance of a web navigation agent. The agent is designed to help a human user navigate a website to complete a task. Given the user's task, the agent's action history, key points for task completion, some potentially important web pages in the agent's trajectory and their reasons, your goal is to determine whether the agent has completed the task and achieved all requirements.
|
294 |
+
|
295 |
+
Your response must strictly follow the following evaluation criteria!
|
296 |
+
*Important Evaluation Criteria*:
|
297 |
+
1: The filtered results must be displayed correctly. If filters were not properly applied (i.e., missing selection, missing confirmation, or no visible effect in results), the task is not considered successful.
|
298 |
+
2: You must carefully check whether these snapshots and action history meet these key points. Ensure that specific filter conditions, such as "best," "highest," "cheapest," "latest," "most recent," "lowest," "closest," "highest-rated," "largest," and "newest" are correctly applied using the filter function(e.g., sort function).
|
299 |
+
3: Certain key points or requirements should be applied by the filter. Otherwise, a search with all requirements as input will be deemed a failure since it cannot guarantee that all results meet the requirements!
|
300 |
+
4: If the task requires filtering by a specific range of money, years, or the number of beds and bathrooms, the applied filter must exactly match the given requirement. Any deviation results in failure. To ensure the task is successful, the applied filter must precisely match the specified range without being too broad or too narrow.
|
301 |
+
Examples of Failure Cases:
|
302 |
+
- If the requirement is less than $50, but the applied filter is less than $25, it is a failure.
|
303 |
+
- If the requirement is $1500-$2500, but the applied filter is $2000-$2500, it is a failure.
|
304 |
+
- If the requirement is $25-$200, but the applied filter is $0-$200, it is a failure.
|
305 |
+
- If the required years are 2004-2012, but the filter applied is 2001-2012, it is a failure.
|
306 |
+
- If the required years are before 2015, but the applied filter is 2000-2014, it is a failure.
|
307 |
+
- If the task requires exactly 2 beds, but the filter applied is 2+ beds, it is a failure.
|
308 |
+
5: Some tasks require a submission action or a display of results to be considered successful.
|
309 |
+
6: If the retrieved information is invalid or empty(e.g., No match was found), but the agent has correctly performed the required action, it should still be considered successful.
|
310 |
+
7: If the current page already displays all available items, then applying a filter is not necessary. As long as the agent selects items that meet the requirements (e.g., the cheapest or lowest price), the task is still considered successful.
|
311 |
+
|
312 |
+
*IMPORTANT*
|
313 |
+
Format your response into two lines as shown below:
|
314 |
+
|
315 |
+
Thoughts: <your thoughts and reasoning process based on double-checking each key points and the evaluation criteria>
|
316 |
+
Status: "success" or "failure"
|
317 |
+
"""
|
318 |
+
prompt = """User Task: {task}
|
319 |
+
|
320 |
+
Key Points: {key_points}
|
321 |
+
|
322 |
+
Action History:
|
323 |
+
{last_actions}
|
324 |
+
|
325 |
+
The potentially important snapshots of the webpage in the agent's trajectory and their reasons:
|
326 |
+
{thoughts}"""
|
327 |
+
|
328 |
+
whole_content_img = []
|
329 |
+
whole_thoughts = []
|
330 |
+
record = []
|
331 |
+
pattern = r"[1-5]"
|
332 |
+
for response, image_path in zip(image_responses, images_path):
|
333 |
+
try:
|
334 |
+
score_text = response.split("Score")[1]
|
335 |
+
thought = response.split("**Reasoning**:")[-1].strip().lstrip("\n").split("\n\n")[0].replace('\n', ' ')
|
336 |
+
score = re.findall(pattern, score_text)[0]
|
337 |
+
record.append({"Response": response, "Score": int(score)})
|
338 |
+
except Exception as e:
|
339 |
+
print(f"Error processing response: {e}")
|
340 |
+
score = 0
|
341 |
+
record.append({"Response": response, "Score": 0})
|
342 |
+
|
343 |
+
if int(score) >= score_threshold:
|
344 |
+
# jpg_base64_str = encode_image(Image.open(image_path))
|
345 |
+
whole_content_img.append(
|
346 |
+
{
|
347 |
+
'type': 'image_url',
|
348 |
+
"image_url": {"url": image_path, "detail": "high"},
|
349 |
+
}
|
350 |
+
)
|
351 |
+
if thought != "":
|
352 |
+
whole_thoughts.append(thought)
|
353 |
+
|
354 |
+
whole_content_img = whole_content_img[:MAX_IMAGE]
|
355 |
+
whole_thoughts = whole_thoughts[:MAX_IMAGE]
|
356 |
+
if len(whole_content_img) == 0:
|
357 |
+
prompt = """User Task: {task}
|
358 |
+
|
359 |
+
Key Points: {key_points}
|
360 |
+
|
361 |
+
Action History:
|
362 |
+
{last_actions}"""
|
363 |
+
text = prompt.format(task=task,
|
364 |
+
last_actions="\n".join(f"{i + 1}. {action}" for i, action in enumerate(last_actions)),
|
365 |
+
key_points=key_points,
|
366 |
+
thoughts="\n".join(f"{i + 1}. {thought}" for i, thought in enumerate(whole_thoughts)))
|
367 |
+
|
368 |
+
messages = [
|
369 |
+
{"role": "system", "content": system_msg},
|
370 |
+
{
|
371 |
+
"role": "user",
|
372 |
+
"content": [
|
373 |
+
{"type": "text", "text": text}]
|
374 |
+
+ whole_content_img
|
375 |
+
}
|
376 |
+
]
|
377 |
+
return messages, text, system_msg, record
|
378 |
+
|
379 |
+
|
380 |
+
# eval code end
|
381 |
+
|
382 |
+
class PlayWrightAgent(Agent):
|
383 |
+
|
384 |
+
def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs):
|
385 |
+
self.screen_capture = True
|
386 |
+
self.step_images = []
|
387 |
+
self.step_thoughts = []
|
388 |
+
self.step_actions = []
|
389 |
+
self.step_results = []
|
390 |
+
self.success = False
|
391 |
+
super().__init__(conf, **kwargs)
|
392 |
+
|
393 |
+
async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> Union[
|
394 |
+
List[ActionModel], None]:
|
395 |
+
"""The strategy of an agent can be to decide which tools to use in the environment, or to delegate tasks to other agents.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
observation: The state observed from tools in the environment.
|
399 |
+
info: Extended information is used to assist the agent to decide a policy.
|
400 |
+
|
401 |
+
Returns:
|
402 |
+
ActionModel sequence from agent policy
|
403 |
+
"""
|
404 |
+
outputs = None
|
405 |
+
if kwargs.get("outputs") and isinstance(kwargs.get("outputs"), Outputs):
|
406 |
+
outputs = kwargs.get("outputs")
|
407 |
+
|
408 |
+
# Get current step information for trace recording
|
409 |
+
step = kwargs.get("step", 0)
|
410 |
+
exp_id = kwargs.get("exp_id", None)
|
411 |
+
source_span = trace.get_current_span()
|
412 |
+
|
413 |
+
if hasattr(observation, 'context') and observation.context:
|
414 |
+
self.task_histories = observation.context
|
415 |
+
|
416 |
+
self._finished = False
|
417 |
+
await self.async_desc_transform()
|
418 |
+
|
419 |
+
self.tools = None
|
420 |
+
if "data:image/jpeg;base64," in observation.content:
|
421 |
+
logger.info("transfer base64 content to image")
|
422 |
+
observation.image = observation.content
|
423 |
+
observation.content = "observation:"
|
424 |
+
self.step_images.append(observation.image)
|
425 |
+
|
426 |
+
images = observation.images if self.conf.use_vision else None
|
427 |
+
if self.conf.use_vision and not images and observation.image:
|
428 |
+
images = [observation.image]
|
429 |
+
|
430 |
+
messages = self.messages_transform(content=observation.content,
|
431 |
+
image_urls=images,
|
432 |
+
sys_prompt=self.system_prompt,
|
433 |
+
agent_prompt=self.agent_prompt)
|
434 |
+
|
435 |
+
self._log_messages(messages)
|
436 |
+
if isinstance(messages[-1]['content'], list):
|
437 |
+
messages[-1]['role'] = 'user' # 有image的话必须使用user请求,而且不写入历史对话
|
438 |
+
# self.memory.add(MemoryItem(
|
439 |
+
# content=messages[-1]['content'],
|
440 |
+
# metadata={
|
441 |
+
# "role": messages[-1]['role'],
|
442 |
+
# "agent_name": self.name(),
|
443 |
+
# }
|
444 |
+
# ))
|
445 |
+
else:
|
446 |
+
self.memory.add(MemoryItem(
|
447 |
+
content=messages[-1]['content'],
|
448 |
+
metadata={
|
449 |
+
"role": messages[-1]['role'],
|
450 |
+
"agent_name": self.name(),
|
451 |
+
}
|
452 |
+
))
|
453 |
+
|
454 |
+
|
455 |
+
|
456 |
+
llm_response = None
|
457 |
+
span_name = f"llm_call_{exp_id}"
|
458 |
+
with trace.span(span_name) as llm_span:
|
459 |
+
llm_span.set_attributes({
|
460 |
+
"exp_id": exp_id,
|
461 |
+
"step": step,
|
462 |
+
"messages": json.dumps([str(m) for m in messages], ensure_ascii=False)
|
463 |
+
})
|
464 |
+
if source_span:
|
465 |
+
source_span.set_attribute("messages", json.dumps([str(m) for m in messages], ensure_ascii=False))
|
466 |
+
|
467 |
+
try:
|
468 |
+
llm_response = await acall_llm_model(
|
469 |
+
self.llm,
|
470 |
+
messages=messages,
|
471 |
+
model=self.model_name,
|
472 |
+
# temperature=self.conf.llm_config.llm_temperature,
|
473 |
+
temperature=0.0,
|
474 |
+
tools=self.tools if not self.use_tools_in_prompt and self.tools else None,
|
475 |
+
stream=kwargs.get("stream", False)
|
476 |
+
)
|
477 |
+
|
478 |
+
# Record LLM response
|
479 |
+
llm_span.set_attributes({
|
480 |
+
"llm_response": json.dumps(llm_response.to_dict(), ensure_ascii=False),
|
481 |
+
"tool_calls": json.dumps([tool_call.model_dump() for tool_call in
|
482 |
+
llm_response.tool_calls] if llm_response.tool_calls else [],
|
483 |
+
ensure_ascii=False),
|
484 |
+
"error": llm_response.error if llm_response.error else ""
|
485 |
+
})
|
486 |
+
|
487 |
+
except Exception as e:
|
488 |
+
logger.warn(traceback.format_exc())
|
489 |
+
llm_span.set_attribute("error", str(e))
|
490 |
+
raise e
|
491 |
+
finally:
|
492 |
+
if llm_response:
|
493 |
+
use_tools = self.use_tool_list(llm_response)
|
494 |
+
is_use_tool_prompt = len(use_tools) > 0
|
495 |
+
if llm_response.error:
|
496 |
+
logger.info(f"llm result error: {llm_response.error}")
|
497 |
+
else:
|
498 |
+
self.memory.add(MemoryItem(
|
499 |
+
content=llm_response.content,
|
500 |
+
metadata={
|
501 |
+
"role": "assistant",
|
502 |
+
"agent_name": self.name(),
|
503 |
+
"tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools,
|
504 |
+
"is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False
|
505 |
+
}
|
506 |
+
))
|
507 |
+
|
508 |
+
function, origin_thought, origin_action, origin_result = parse_tool_call(
|
509 |
+
llm_response.message['content'])
|
510 |
+
self.step_thoughts.append(origin_thought)
|
511 |
+
self.step_actions.append(origin_action)
|
512 |
+
self.step_results.append(origin_result)
|
513 |
+
|
514 |
+
if function.name == "finished":
|
515 |
+
self._finished = True
|
516 |
+
llm_response.content = "<answer>" + llm_response.content + "</answer>"
|
517 |
+
llm_response.tool_calls = None
|
518 |
+
else:
|
519 |
+
llm_response.content = None
|
520 |
+
|
521 |
+
tool_call = ToolCall(
|
522 |
+
id="tooluse_mock",
|
523 |
+
type="function",
|
524 |
+
function=function,
|
525 |
+
)
|
526 |
+
screen_capture = ToolCall(
|
527 |
+
id="screen_capture",
|
528 |
+
type="function",
|
529 |
+
function=Function(
|
530 |
+
name="mcp__ms-playwright__browser_screen_capture",
|
531 |
+
arguments="{}"
|
532 |
+
)
|
533 |
+
)
|
534 |
+
llm_response.tool_calls = [tool_call, screen_capture]
|
535 |
+
else:
|
536 |
+
logger.error(f"{self.name()} failed to get LLM response")
|
537 |
+
raise RuntimeError(f"{self.name()} failed to get LLM response")
|
538 |
+
|
539 |
+
if outputs and isinstance(outputs, Outputs):
|
540 |
+
await outputs.add_output(MessageOutput(source=llm_response, json_parse=False))
|
541 |
+
|
542 |
+
agent_result = sync_exec(self.resp_parse_func, llm_response)
|
543 |
+
if not agent_result.is_call_tool:
|
544 |
+
self._finished = True
|
545 |
+
|
546 |
+
logger.info(self.step_thoughts)
|
547 |
+
logger.info(self.step_actions)
|
548 |
+
|
549 |
+
# now is eval code:
|
550 |
+
logger.info(f"step:{step}")
|
551 |
+
|
552 |
+
if self.finished or step >= 20: # 暂时写死,这里应该是max_step
|
553 |
+
task = self.task.split("Please first navigate to the target")[0]
|
554 |
+
key_points_messages = identify_key_points(task)
|
555 |
+
|
556 |
+
# eval_model_name = "shangshu.gpt-4o"
|
557 |
+
eval_model_name = self.model_name
|
558 |
+
tmp_llm_response = await acall_llm_model(
|
559 |
+
self.llm,
|
560 |
+
messages=key_points_messages,
|
561 |
+
model=eval_model_name,
|
562 |
+
temperature=0
|
563 |
+
)
|
564 |
+
|
565 |
+
key_points = tmp_llm_response.content
|
566 |
+
key_points = key_points.replace("\n\n", "\n")
|
567 |
+
|
568 |
+
try:
|
569 |
+
key_points = key_points.split("**Key Points**:")[1]
|
570 |
+
key_points = "\n".join(line.lstrip() for line in key_points.splitlines())
|
571 |
+
except:
|
572 |
+
key_points = key_points.split("Key Points:")[-1]
|
573 |
+
key_points = "\n".join(line.lstrip() for line in key_points.splitlines())
|
574 |
+
|
575 |
+
logger.info(f"key_points: {key_points}")
|
576 |
+
|
577 |
+
tasks_messages = [judge_image(task, image_path, key_points) for image_path in self.step_images]
|
578 |
+
|
579 |
+
# 这里暂时使用串行执行的写法
|
580 |
+
image_responses = []
|
581 |
+
for task_messages in tasks_messages:
|
582 |
+
logger.info(task_messages)
|
583 |
+
image_response = await acall_llm_model(
|
584 |
+
self.llm, # 假设这是你传给函数的第一个参数
|
585 |
+
messages=task_messages, # 每个请求的消息内容
|
586 |
+
model=eval_model_name, # 模型名称
|
587 |
+
temperature=0 # 温度参数
|
588 |
+
)
|
589 |
+
image_responses.append(image_response)
|
590 |
+
|
591 |
+
image_responses = [i.content for i in image_responses]
|
592 |
+
|
593 |
+
logger.info(f"image_responses: {image_responses}")
|
594 |
+
|
595 |
+
eval_messages, text, system_msg, record = WebJudge_Online_Mind2Web_eval(
|
596 |
+
self.task, self.step_actions, self.step_images, image_responses, key_points, 3)
|
597 |
+
response = await acall_llm_model(
|
598 |
+
self.llm,
|
599 |
+
messages=eval_messages,
|
600 |
+
model=eval_model_name,
|
601 |
+
temperature=0
|
602 |
+
)
|
603 |
+
eval_response = response.content
|
604 |
+
|
605 |
+
logger.info(f"eval_response: {eval_response}")
|
606 |
+
|
607 |
+
if "success" in eval_response.lower().split('status:')[1]:
|
608 |
+
self.success = True
|
609 |
+
|
610 |
+
# now is saving code:
|
611 |
+
|
612 |
+
result_dict = {
|
613 |
+
'task': task,
|
614 |
+
'images': self.step_images,
|
615 |
+
'actions': self.step_actions,
|
616 |
+
'thoughts': self.step_thoughts,
|
617 |
+
'results': self.step_results,
|
618 |
+
'success': self.success,
|
619 |
+
'final_answer': llm_response.content,
|
620 |
+
'eval_response': eval_response,
|
621 |
+
'is_done': self.finished,
|
622 |
+
'done_step': step,
|
623 |
+
}
|
624 |
+
result_dict = json.dumps(result_dict, ensure_ascii=False)
|
625 |
+
|
626 |
+
agent_result.actions[0].policy_info = result_dict
|
627 |
+
agent_result.actions[0].tool_name = None
|
628 |
+
agent_result.actions[0].action_name = None
|
629 |
+
agent_result.actions[0].agent_name = self.name()
|
630 |
+
# saving is over...
|
631 |
+
|
632 |
+
return agent_result.actions
|
633 |
+
|
634 |
+
|
635 |
+
class Pipeline(AworldBaseAgent):
|
636 |
+
class Valves(BaseModel):
|
637 |
+
llm_provider: Optional[str] = Field(default=None, description="llm_model_name")
|
638 |
+
llm_model_name: Optional[str] = Field(default=None, description="llm_model_name")
|
639 |
+
llm_base_url: Optional[str] = Field(default=None, description="llm_base_urly")
|
640 |
+
llm_api_key: Optional[str] = Field(default=None, description="llm api key")
|
641 |
+
system_prompt: str = Field(default=BROWSER_SYSTEM_PROMPT, description="system_prompt")
|
642 |
+
history_messages: int = Field(default=100, description="rounds of history messages")
|
643 |
+
|
644 |
+
def __init__(self):
|
645 |
+
self.valves = self.Valves()
|
646 |
+
self.agent_config = AgentConfig(
|
647 |
+
name=self.agent_name(),
|
648 |
+
llm_provider=self.valves.llm_provider if self.valves.llm_provider else os.environ.get("LLM_PROVIDER"),
|
649 |
+
llm_model_name=self.valves.llm_model_name if self.valves.llm_model_name else os.environ.get(
|
650 |
+
"LLM_MODEL_NAME"),
|
651 |
+
llm_api_key=self.valves.llm_api_key if self.valves.llm_api_key else os.environ.get("LLM_API_KEY"),
|
652 |
+
llm_base_url=self.valves.llm_base_url if self.valves.llm_base_url else os.environ.get("LLM_BASE_URL"),
|
653 |
+
system_prompt=self.valves.system_prompt if self.valves.system_prompt else BROWSER_SYSTEM_PROMPT
|
654 |
+
)
|
655 |
+
|
656 |
+
self.m2w_files = os.path.abspath(os.path.join(os.path.curdir, "aworldspace", "datasets", "online-mind2web"))
|
657 |
+
|
658 |
+
logging.info(f"m2w_files path {self.m2w_files}")
|
659 |
+
file_path = os.path.join(self.m2w_files, "Online_Mind2Web.json")
|
660 |
+
|
661 |
+
with open(file_path, 'r') as file:
|
662 |
+
self.full_dataset = json.load(file)
|
663 |
+
logging.info("playwright_agent init success")
|
664 |
+
|
665 |
+
# 重写build_agent
|
666 |
+
async def build_agent(self, body: dict):
|
667 |
+
agent_config = await self.get_agent_config(body)
|
668 |
+
mcp_servers = await self.get_mcp_servers(body)
|
669 |
+
|
670 |
+
agent = PlayWrightAgent(
|
671 |
+
conf=agent_config,
|
672 |
+
name=agent_config.name,
|
673 |
+
system_prompt=agent_config.system_prompt,
|
674 |
+
mcp_servers=mcp_servers,
|
675 |
+
mcp_config=await self.load_mcp_config(),
|
676 |
+
history_messages=await self.get_history_messages(body)
|
677 |
+
)
|
678 |
+
return agent
|
679 |
+
|
680 |
+
async def get_custom_input(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Any:
|
681 |
+
task = await self.get_m2w_task(int(user_message))
|
682 |
+
return task['Task']
|
683 |
+
|
684 |
+
async def get_agent_config(self, body):
|
685 |
+
default_llm_provider = self.valves.llm_provider if self.valves.llm_provider else os.environ.get("LLM_PROVIDER")
|
686 |
+
llm_model_name = self.valves.llm_model_name if self.valves.llm_model_name else os.environ.get("LLM_MODEL_NAME")
|
687 |
+
llm_api_key = self.valves.llm_api_key if self.valves.llm_api_key else os.environ.get("LLM_API_KEY")
|
688 |
+
llm_base_url = self.valves.llm_base_url if self.valves.llm_base_url else os.environ.get("LLM_BASE_URL")
|
689 |
+
system_prompt = self.valves.system_prompt if self.valves.system_prompt else BROWSER_SYSTEM_PROMPT
|
690 |
+
|
691 |
+
task = await self.get_task_from_body(body)
|
692 |
+
logging.info(
|
693 |
+
f"task llm config is: {task.llm_provider}, {task.llm_model_name}, {task.llm_api_key}, {task.llm_base_url}")
|
694 |
+
|
695 |
+
return AgentConfig(
|
696 |
+
name=self.agent_name(),
|
697 |
+
llm_provider=task.llm_provider if task and task.llm_provider else default_llm_provider,
|
698 |
+
llm_model_name=task.llm_model_name if task and task.llm_model_name else llm_model_name,
|
699 |
+
llm_api_key=task.llm_api_key if task and task.llm_api_key else llm_api_key,
|
700 |
+
llm_base_url=task.llm_base_url if task and task.llm_base_url else llm_base_url,
|
701 |
+
system_prompt=task.task_system_prompt if task and task.task_system_prompt else system_prompt
|
702 |
+
)
|
703 |
+
|
704 |
+
def agent_name(self) -> str:
|
705 |
+
return "PlaywrightAgent"
|
706 |
+
|
707 |
+
async def get_mcp_servers(self, body) -> list[str]:
|
708 |
+
task = await self.get_task_from_body(body)
|
709 |
+
if task.mcp_servers:
|
710 |
+
logging.info(f"mcp_servers from task: {task.mcp_servers}")
|
711 |
+
return task.mcp_servers
|
712 |
+
|
713 |
+
return [
|
714 |
+
"ms-playwright"
|
715 |
+
]
|
716 |
+
|
717 |
+
async def get_m2w_task(self, index) -> dict:
|
718 |
+
logging.info(f"Start to process: m2w_task_{index}")
|
719 |
+
m2w_task = self.full_dataset[index]
|
720 |
+
logging.info(f"Detail: {m2w_task}")
|
721 |
+
logging.info(f"Task: {m2w_task['confirmed_task']}")
|
722 |
+
logging.info(f"Level: {m2w_task['level']}")
|
723 |
+
logging.info(f"Website: {m2w_task['website']}")
|
724 |
+
|
725 |
+
return self.add_file_path(m2w_task)
|
726 |
+
|
727 |
+
async def custom_output_before_task(self, outputs: Outputs, chat_id: str, task: Task) -> None:
|
728 |
+
task_config: TaskConfig = task.conf
|
729 |
+
m2w_task = await self.get_m2w_task(int(task_config.ext['origin_message']))
|
730 |
+
|
731 |
+
result = f"\n\n`Web TASK#{task_config.ext['origin_message']}`\n\n---\n\n"
|
732 |
+
result += f"**Task**: {m2w_task['Task']}\n"
|
733 |
+
result += f"**Level**: {m2w_task['level']}\n"
|
734 |
+
result += f"**Website**: \n {m2w_task['website']}\n"
|
735 |
+
result += f"\n\n-----\n\n"
|
736 |
+
await outputs.add_output(Output(data=result))
|
737 |
+
|
738 |
+
async def custom_output_after_task(self, outputs: Outputs, chat_id: str, task: Task):
|
739 |
+
"""
|
740 |
+
check gaia task output
|
741 |
+
Args:
|
742 |
+
outputs:
|
743 |
+
chat_id:
|
744 |
+
task:
|
745 |
+
|
746 |
+
Returns:
|
747 |
+
|
748 |
+
"""
|
749 |
+
task_config: TaskConfig = task.conf
|
750 |
+
web_task_id = int(task_config['ext']['origin_message'])
|
751 |
+
web_task = await self.get_m2w_task(web_task_id)
|
752 |
+
agent_result = ""
|
753 |
+
if isinstance(outputs, StreamingOutputs):
|
754 |
+
agent_result = await outputs._visited_outputs[-2].get_finished_response() # read llm result
|
755 |
+
# match = re.search(r"<answer>(.*?)</answer>", agent_result)
|
756 |
+
result = ""
|
757 |
+
# if match:
|
758 |
+
# answer = match.group(1)
|
759 |
+
logging.info(f"Agent answer: {agent_result}")
|
760 |
+
|
761 |
+
metadata = await outputs.get_metadata()
|
762 |
+
if not metadata:
|
763 |
+
await outputs.set_metadata({})
|
764 |
+
metadata = await outputs.get_metadata()
|
765 |
+
metadata['web_task'] = web_task
|
766 |
+
return result
|
767 |
+
|
768 |
+
def add_file_path(self, task: Dict[str, Any]
|
769 |
+
):
|
770 |
+
task["Task"] = "Task: " + task['confirmed_task'] + '\n' + "Please first navigate to the target " + "Website: " + \
|
771 |
+
task['website']
|
772 |
+
return task
|
773 |
+
|
774 |
+
async def load_mcp_config(self) -> dict:
|
775 |
+
return {
|
776 |
+
"mcpServers": {
|
777 |
+
"ms-playwright": {
|
778 |
+
"command": "npx",
|
779 |
+
"args": [
|
780 |
+
"@playwright/[email protected]",
|
781 |
+
"--vision",
|
782 |
+
"--no-sandbox",
|
783 |
+
"--headless",
|
784 |
+
"--isolated"
|
785 |
+
],
|
786 |
+
"env": {
|
787 |
+
"PLAYWRIGHT_TIMEOUT": "120000",
|
788 |
+
"SESSION_REQUEST_CONNECT_TIMEOUT": "120"
|
789 |
+
}
|
790 |
+
}
|
791 |
+
}
|
792 |
+
}
|