Duibonduil commited on
Commit
297ed13
·
verified ·
1 Parent(s): 4081a7b

Upload 3 files

Browse files
aworld/cmd/utils/agent_loader.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib
3
+ import subprocess
4
+ import sys
5
+ import traceback
6
+ import logging
7
+ from typing import List, Dict
8
+ from .. import AgentModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ _agent_cache: Dict[str, AgentModel] = {}
13
+
14
+
15
+ def list_agents() -> List[AgentModel]:
16
+ """
17
+ List all cached agents
18
+
19
+ Returns:
20
+ List[AgentModel]: The list of agent models
21
+ """
22
+ if len(_agent_cache) == 0:
23
+ for m in _list_agents():
24
+ _agent_cache[m.id] = m
25
+ return _agent_cache
26
+
27
+
28
+ def get_agent(agent_id) -> AgentModel:
29
+ """
30
+ Get the agent model by agent name
31
+
32
+ Args:
33
+ agent_id: The name of the agent
34
+
35
+ Returns:
36
+ AgentModel: The agent model
37
+ """
38
+ if len(_agent_cache) == 0:
39
+ list_agents()
40
+ if agent_id not in _agent_cache:
41
+ raise Exception(f"Agent {agent_id} not found")
42
+ return _agent_cache[agent_id]
43
+
44
+
45
+ def _list_agents() -> List[AgentModel]:
46
+ agents_dir = os.path.join(os.getcwd(), "agent_deploy")
47
+
48
+ if not os.path.exists(agents_dir):
49
+ logger.warning(f"Agents directory {agents_dir} does not exist")
50
+ return []
51
+
52
+ if agents_dir not in sys.path:
53
+ sys.path.append(agents_dir)
54
+
55
+ agents = []
56
+ for agent_id in os.listdir(agents_dir):
57
+ try:
58
+ agent_path = os.path.join(agents_dir, agent_id)
59
+ if os.path.isdir(agent_path):
60
+ requirements_file = os.path.join(agent_path, "requirements.txt")
61
+ if os.path.exists(requirements_file):
62
+ p = subprocess.Popen(
63
+ ["pip", "install", "-r", requirements_file],
64
+ cwd=agent_path,
65
+ )
66
+ p.wait()
67
+ if p.returncode != 0:
68
+ logger.error(
69
+ f"Error installing requirements for agent {agent_id}, path {agent_path}"
70
+ )
71
+ continue
72
+
73
+ agent_file = os.path.join(agent_path, "agent.py")
74
+ if os.path.exists(agent_file):
75
+ try:
76
+ instance = _get_agent_instance(agent_id)
77
+ if hasattr(instance, "name"):
78
+ name = instance.name()
79
+ else:
80
+ name = agent_id
81
+ if hasattr(instance, "description"):
82
+ description = instance.description()
83
+ else:
84
+ description = ""
85
+ agent_model = AgentModel(
86
+ id=agent_id,
87
+ name=name,
88
+ description=description,
89
+ path=agent_path,
90
+ instance=instance,
91
+ )
92
+
93
+ agents.append(agent_model)
94
+ logger.info(
95
+ f"Loaded agent {agent_id} successfully, path {agent_path}"
96
+ )
97
+ except Exception as e:
98
+ logger.error(
99
+ f"Error loading agent {agent_id}: {traceback.format_exc()}"
100
+ )
101
+ continue
102
+ else:
103
+ logger.warning(f"Agent {agent_id} does not have agent.py file")
104
+ except Exception as e:
105
+ logger.error(
106
+ f"Error loading agent {agent_id}, path {agent_path} : {traceback.format_exc()}"
107
+ )
108
+ continue
109
+
110
+ return agents
111
+
112
+
113
+ def _get_agent_instance(agent_name):
114
+ try:
115
+ agent_module = importlib.import_module(
116
+ name=f"{agent_name}.agent",
117
+ )
118
+ except Exception as e:
119
+ msg = f"Error loading agent {agent_name}, cwd:{os.getcwd()}, sys.path:{sys.path}: {traceback.format_exc()}"
120
+ logger.error(msg)
121
+ raise Exception(msg)
122
+
123
+ if hasattr(agent_module, "AWorldAgent"):
124
+ agent = agent_module.AWorldAgent()
125
+ return agent
126
+ else:
127
+ raise Exception(f"Agent {agent_name} does not have AWorldAgent class")
aworld/cmd/utils/agent_server.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from aworld.cmd import AgentModel, ChatCompletionMessage, ChatCompletionRequest
4
+ from aworld.session.base_session_service import BaseSessionService
5
+ from aworld.session.simple_session_service import SimpleSessionService
6
+
7
+
8
+ class AgentServer:
9
+ def __init__(
10
+ self,
11
+ server_id: str,
12
+ server_name: str,
13
+ session_service: BaseSessionService = None,
14
+ ):
15
+ """
16
+ Initialize AgentServer
17
+ """
18
+ self.server_id = server_id
19
+ self.server_name = server_name
20
+ self.agent_list = []
21
+ self.session_service = session_service or SimpleSessionService()
22
+
23
+ def get_session_service(self) -> BaseSessionService:
24
+ return self.session_service
25
+
26
+ def get_agent_list(self) -> List[AgentModel]:
27
+ return self.agent_list
28
+
29
+ async def on_chat_completion_request(self, request: ChatCompletionRequest):
30
+ await self.get_session_service().append_messages(
31
+ request.user_id,
32
+ request.session_id,
33
+ request.messages,
34
+ )
35
+
36
+ async def on_chat_completion_end(
37
+ self, request: ChatCompletionRequest, final_response: str
38
+ ):
39
+ await self.get_session_service().append_messages(
40
+ request.user_id,
41
+ request.session_id,
42
+ [
43
+ ChatCompletionMessage(
44
+ role="assistant",
45
+ content=final_response,
46
+ trace_id=request.trace_id,
47
+ ),
48
+ ],
49
+ )
50
+
51
+
52
+ CURRENT_SERVER = AgentServer(server_id="default", server_name="default")
aworld/cmd/utils/agent_ui_parser.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import uuid
3
+ from dataclasses import dataclass
4
+
5
+ from pydantic import Field, BaseModel
6
+
7
+ from aworld.output import (
8
+ MessageOutput,
9
+ AworldUI,
10
+ Output,
11
+ Artifact,
12
+ ArtifactType,
13
+ WorkSpace,
14
+ SearchOutput,
15
+ )
16
+ from aworld.output.base import StepOutput, ToolResultOutput
17
+ from aworld.output.utils import consume_content
18
+ from abc import ABC, abstractmethod
19
+ from typing_extensions import override
20
+
21
+
22
+ class ToolCard(BaseModel):
23
+ tool_type: str = Field(None, description="tool type")
24
+ tool_name: str = Field(None, description="tool name")
25
+ function_name: str = Field(None, description="function name")
26
+ tool_call_id: str = Field(None, description="tool call id")
27
+ arguments: str = Field(None, description="arguments")
28
+ results: str = Field(None, description="results")
29
+ card_type: str = Field(None, description="card type")
30
+ card_data: dict = Field(None, description="card data")
31
+
32
+ @staticmethod
33
+ def from_tool_result(output: ToolResultOutput) -> "ToolCard":
34
+ return ToolCard(
35
+ tool_type=output.tool_type,
36
+ tool_name=output.tool_name,
37
+ function_name=output.origin_tool_call.function.name,
38
+ tool_call_id=output.origin_tool_call.id,
39
+ arguments=output.origin_tool_call.function.arguments,
40
+ results=output.data,
41
+ )
42
+
43
+
44
+ class BaseToolResultParser(ABC):
45
+
46
+ def __init__(self, tool_name: str = None):
47
+ self.tool_name = tool_name
48
+
49
+ @abstractmethod
50
+ async def parse(self, output: ToolResultOutput):
51
+ pass
52
+
53
+
54
+ class DefaultToolResultParser(BaseToolResultParser):
55
+
56
+ @override
57
+ async def parse(self, output: ToolResultOutput):
58
+ tool_card = ToolCard.from_tool_result(output)
59
+
60
+ tool_card.card_type = "tool_call_card_default"
61
+
62
+ return f"""\
63
+ **🔧 Tool: {tool_card.tool_name}#{tool_card.function_name}**\n\n
64
+ ```tool_card
65
+ {json.dumps(tool_card.model_dump(), ensure_ascii=False, indent=2)}
66
+ ```
67
+ """
68
+
69
+
70
+ class GooglePseSearchToolResultParser(BaseToolResultParser):
71
+
72
+ @override
73
+ async def parse(self, output: ToolResultOutput):
74
+ tool_card = ToolCard.from_tool_result(output)
75
+
76
+ query = ""
77
+ try:
78
+ args = json.loads(tool_card.arguments)
79
+ query = args.get("query")
80
+ except Exception:
81
+ pass
82
+
83
+ result_items = []
84
+ try:
85
+ result_items = json.loads(tool_card.results)
86
+ except Exception:
87
+ pass
88
+
89
+ tool_card.card_type = "tool_call_card_link_list"
90
+ tool_card.card_data = {
91
+ "title": "🔎 Google Search",
92
+ "query": query,
93
+ "search_items": result_items,
94
+ }
95
+
96
+ return f"""\
97
+ **🔎 Google Search**\n\n
98
+ ```tool_card
99
+ {json.dumps(tool_card.model_dump(), ensure_ascii=False, indent=2)}
100
+ ```
101
+ """
102
+
103
+
104
+ class ToolResultParserFactory:
105
+ _parsers = {}
106
+
107
+ @staticmethod
108
+ def register_parser(parser: BaseToolResultParser):
109
+ ToolResultParserFactory._parsers[parser.tool_name] = parser
110
+
111
+ @staticmethod
112
+ def get_parser(tool_type: str, tool_name: str):
113
+ if "search" in tool_name:
114
+ return GooglePseSearchToolResultParser()
115
+ else:
116
+ return DefaultToolResultParser()
117
+
118
+
119
+ @dataclass
120
+ class AWorldAgentUI(AworldUI):
121
+
122
+ session_id: str = Field(default="", description="session id")
123
+ workspace: WorkSpace = Field(default=None, description="workspace")
124
+ cur_agent_name: str = Field(default=None, description="cur agent name")
125
+
126
+ def __init__(self, session_id: str = None, workspace: WorkSpace = None, **kwargs):
127
+ """
128
+ Initialize MarkdownAworldUI
129
+ Args:"""
130
+ super().__init__(**kwargs)
131
+ self.session_id = session_id
132
+ self.workspace = workspace
133
+
134
+ @override
135
+ async def message_output(self, __output__: MessageOutput):
136
+ """
137
+ Returns an async generator that yields each message item.
138
+ """
139
+ # Sentinel object for queue completion
140
+ _SENTINEL = object()
141
+
142
+ async def async_generator():
143
+ async def __log_item(item):
144
+ await queue.put(item)
145
+
146
+ from asyncio import Queue
147
+
148
+ queue = Queue()
149
+
150
+ async def consume_all():
151
+ # Consume all relevant generators
152
+ if __output__.reason_generator or __output__.response_generator:
153
+ if __output__.reason_generator:
154
+ await consume_content(__output__.reason_generator, __log_item)
155
+ if __output__.response_generator:
156
+ await consume_content(__output__.response_generator, __log_item)
157
+ else:
158
+ await consume_content(__output__.reasoning, __log_item)
159
+ await consume_content(__output__.response, __log_item)
160
+ # Only after all are done, put the sentinel
161
+ await queue.put(_SENTINEL)
162
+
163
+ # Start the consumer in the background
164
+ import asyncio
165
+
166
+ consumer_task = asyncio.create_task(consume_all())
167
+
168
+ while True:
169
+ item = await queue.get()
170
+ if item is _SENTINEL:
171
+ break
172
+ yield item
173
+ await consumer_task # Ensure background task is finished
174
+
175
+ return async_generator()
176
+
177
+ @override
178
+ async def tool_result(self, output: ToolResultOutput):
179
+ """
180
+ tool_result
181
+ """
182
+ parser = ToolResultParserFactory.get_parser(output.tool_type, output.tool_name)
183
+ return await parser.parse(output)
184
+
185
+ async def _gen_custom_output(self, output):
186
+ """
187
+ hook for custom output
188
+ """
189
+ custom_output = f"{output.tool_name}#{output.origin_tool_call.function.name}"
190
+ if (
191
+ output.tool_name == "aworld-playwright"
192
+ and output.origin_tool_call.function.name == "browser_navigate"
193
+ ):
194
+ custom_output = f"🔍 search `{json.loads(output.origin_tool_call.function.arguments)['url']}`"
195
+ if (
196
+ output.tool_name == "aworldsearch-server"
197
+ and output.origin_tool_call.function.name == "search"
198
+ ):
199
+ custom_output = f"🔍 search keywords: {' '.join(json.loads(output.origin_tool_call.function.arguments)['query_list'])}"
200
+ return custom_output
201
+
202
+ @override
203
+ async def step(self, output: StepOutput):
204
+ emptyLine = "\n\n----\n\n"
205
+ if output.status == "START":
206
+ self.cur_agent_name = output.name
207
+ return f"\n\n # {output.show_name} \n\n"
208
+ elif output.status == "FINISHED":
209
+ return f"{emptyLine}"
210
+ elif output.status == "FAILED":
211
+ return f"\n\n{output.name} 💥FAILED: reason is {output.data} {emptyLine}"
212
+ else:
213
+ return f"\n\n{output.name} ❓❓❓UNKNOWN#{output.status} {emptyLine}"
214
+
215
+ @override
216
+ async def custom_output(self, output: Output):
217
+ return output.data
218
+
219
+ async def _parse_tool_artifacts(self, metadata):
220
+ result = []
221
+ if not metadata:
222
+ return result
223
+
224
+ # screenshots
225
+ if (
226
+ metadata.get("screenshots")
227
+ and isinstance(metadata.get("screenshots"), list)
228
+ and len(metadata.get("screenshots")) > 0
229
+ ):
230
+ for index, screenshot in enumerate(metadata.get("screenshots")):
231
+ image_artifact = Artifact(
232
+ artifact_id=str(uuid.uuid4()),
233
+ artifact_type=ArtifactType.IMAGE,
234
+ content=screenshot.get("ossPath"),
235
+ )
236
+ await self.workspace.add_artifact(image_artifact)
237
+ result.append(
238
+ {
239
+ "artifact_type": "IMAGE",
240
+ "artifact_id": image_artifact.artifact_id,
241
+ }
242
+ )
243
+
244
+ # web_pages
245
+ if metadata.get("artifact_type") == "WEB_PAGES":
246
+ search_output = SearchOutput.from_dict(metadata.get("artifact_data"))
247
+ artifact_id = str(uuid.uuid4())
248
+ await self.workspace.create_artifact(
249
+ artifact_type=ArtifactType.WEB_PAGES,
250
+ artifact_id=artifact_id,
251
+ content=search_output,
252
+ metadata={
253
+ "query": search_output.query,
254
+ },
255
+ )
256
+ result.append({"artifact_type": "WEB_PAGES", "artifact_id": artifact_id})
257
+ return result