Duibonduil commited on
Commit
d3b8afc
·
verified ·
1 Parent(s): 900b15b

Upload 2 files

Browse files
AWorld-main/aworlddistributed/aworldspace/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from aworldspace.base_agent import AworldBaseAgent
6
+
7
+ """
8
+ Agent Space
9
+ """
10
+
11
+ class AgentMeta(BaseModel):
12
+ name: str = None
13
+ desc: str = None
14
+
15
+
16
+
17
+ class AgentSpace(BaseModel):
18
+ agent_modules: Optional[dict] = Field(default_factory=dict, description="agent module")
19
+ agents_meta: Optional[dict] = Field(default_factory=dict, description="agents meta")
20
+
21
+ def register(self, agent_name: str, agent_instance: AworldBaseAgent, metadata: dict=None):
22
+ # Register agent metadata and instance
23
+ self.agent_modules[agent_name] = agent_instance
24
+
25
+ async def get_agent_modules(self):
26
+ return self.agent_modules
27
+
28
+ async def get_agents_meta(self):
29
+ return self.agents_meta
30
+
31
+
32
+ AGENT_SPACE = AgentSpace()
AWorld-main/aworlddistributed/aworldspace/base_agent.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ import traceback
5
+ import uuid
6
+ from abc import abstractmethod
7
+ from typing import List, AsyncGenerator, Any
8
+
9
+ from aworld.config import AgentConfig, TaskConfig, ContextRuleConfig, OptimizationConfig
10
+ from aworld.agents.llm_agent import Agent
11
+ from aworld.core.task import Task
12
+ from aworld.output import WorkSpace, AworldUI, Outputs
13
+ from aworld.output.ui.markdown_aworld_ui import MarkdownAworldUI
14
+ from aworld.output.utils import load_workspace
15
+ from aworld.runner import Runners
16
+
17
+ from client.aworld_client import AworldTask
18
+
19
+
20
+ class AworldBaseAgent:
21
+
22
+ def pipes(self) -> list[dict]:
23
+ return [{"id": self.agent_name(), "name": self.agent_name()}]
24
+
25
+
26
+ @abstractmethod
27
+ def agent_name(self) -> str:
28
+ pass
29
+
30
+
31
+ async def pipe(
32
+ self,
33
+ user_message: str,
34
+ model_id: str,
35
+ messages: List[dict],
36
+ body: dict
37
+ ):
38
+
39
+ try:
40
+ logging.info(f"🤖{self.agent_name()} received user_message is {user_message}, form-data = {body}")
41
+
42
+ task = await self.get_task_from_body(body)
43
+
44
+ if task:
45
+ logging.info(f"🤖{self.agent_name()} received task is {task.task_id}_{task.client_id}_{task.user_id}")
46
+ task_id = task.task_id
47
+ else:
48
+ task_id = str(uuid.uuid4())
49
+
50
+ session_id = task_id
51
+ if body.get('metadata'):
52
+ # user_id = body.get('metadata').get('user_id')
53
+ session_id = body.get('metadata').get('chat_id', task_id)
54
+ task_id = body.get('metadata').get('message_id', task_id)
55
+
56
+ user_input = await self.get_custom_input(user_message, model_id, messages, body)
57
+ if task and task.llm_custom_input:
58
+ user_input = task.llm_custom_input
59
+ logging.info(f"🤖{self.agent_name()} call llm input is [{user_input}]")
60
+
61
+ # build agent task read from config
62
+ swarm = await self.build_swarm(body=body)
63
+ agent = None
64
+ if not swarm:
65
+ # build single agent task read from config
66
+ agent = await self.build_agent(body=body)
67
+ logging.info(f"🤖{self.agent_name()} build agent finished")
68
+
69
+
70
+
71
+ # return task
72
+ task = await self.build_task(agent=agent, task_id=task_id, user_input=user_input, user_message=user_message, body=body)
73
+ logging.info(f"🤖{self.agent_name()} build task finished, task_id is {task_id}")
74
+
75
+
76
+ workspace_type = os.environ.get("WORKSPACE_TYPE", "local")
77
+ workspace_path = os.environ.get("WORKSPACE_PATH", "./data/workspaces")
78
+ workspace = await load_workspace(session_id, workspace_type, workspace_path)
79
+ # render output
80
+ async_generator = await self.parse_task_output(session_id, task, workspace)
81
+
82
+ return async_generator()
83
+
84
+ except Exception as e:
85
+ return await self._format_exception(e)
86
+
87
+ async def _format_exception(self, e: Exception) -> str:
88
+ traceback.print_exc()
89
+ # tb_lines = traceback.format_exception(type(e), e, e.__traceback__)
90
+ # detailed_error = "".join(tb_lines)
91
+ # logging.error(e)
92
+ # return json.dumps({"error": detailed_error}, ensure_ascii=False)
93
+ return "💥💥💥process failed💥💥💥"
94
+
95
+
96
+ async def _format_error(self, status_code: int, error: bytes) -> str:
97
+ if isinstance(error, str):
98
+ error_str = error
99
+ else:
100
+ error_str = error.decode(errors="ignore")
101
+ try:
102
+ err_msg = json.loads(error_str).get("message", error_str)[:200]
103
+ except Exception:
104
+ err_msg = error_str[:200]
105
+ return json.dumps(
106
+ {"error": f"HTTP {status_code}: {err_msg}"}, ensure_ascii=False
107
+ )
108
+
109
+ async def get_custom_input(self, user_message: str,
110
+ model_id: str,
111
+ messages: List[dict],
112
+ body: dict) -> Any:
113
+ user_input = body["messages"][-1]["content"]
114
+ return user_input
115
+
116
+ @abstractmethod
117
+ async def get_history_messages(self, body) -> int:
118
+ task = await self.get_task_from_body(body)
119
+ if task:
120
+ return task.history_messages
121
+ return 100
122
+
123
+ @abstractmethod
124
+ async def get_agent_config(self, body) -> AgentConfig:
125
+ pass
126
+
127
+ @abstractmethod
128
+ async def get_mcp_servers(self, body) -> list[str]:
129
+ pass
130
+
131
+ async def build_agent(self, body: dict):
132
+
133
+ agent_config =await self.get_agent_config(body)
134
+ mcp_servers = await self.get_mcp_servers(body)
135
+ agent = Agent(
136
+ conf=agent_config,
137
+ name=agent_config.name,
138
+ system_prompt=agent_config.system_prompt,
139
+ mcp_servers=mcp_servers,
140
+ mcp_config=await self.load_mcp_config(),
141
+ history_messages=await self.get_history_messages(body),
142
+ context_rule=ContextRuleConfig(
143
+ optimization_config=OptimizationConfig(
144
+ enabled=False,
145
+ )
146
+ )
147
+ )
148
+ return agent
149
+
150
+ async def build_task(self, agent, task_id, user_input, user_message, body):
151
+ aworld_task = await self.get_task_from_body(body)
152
+ task = Task(
153
+ id=task_id,
154
+ name=task_id,
155
+ input=user_input,
156
+ agent=agent,
157
+ conf=TaskConfig(
158
+ task_id=task_id,
159
+ stream=False,
160
+ ext={
161
+ "origin_message": user_message
162
+ },
163
+ max_steps=aworld_task.max_steps if aworld_task else 100
164
+ )
165
+ )
166
+ return task
167
+
168
+
169
+
170
+ async def parse_task_output(self, chat_id, task: Task, workspace: WorkSpace):
171
+ _SENTINEL = object()
172
+
173
+ async def async_generator():
174
+
175
+ from asyncio import Queue
176
+ queue = Queue()
177
+
178
+ async def consume_all():
179
+ openwebui_ui = MarkdownAworldUI(
180
+ session_id=chat_id,
181
+ workspace=workspace
182
+ )
183
+
184
+ # get outputs
185
+ outputs = Runners.streamed_run_task(task)
186
+
187
+ # output hooks
188
+ await self.custom_output_before_task(outputs, chat_id, task)
189
+
190
+ # render output
191
+ try:
192
+ async for output in outputs.stream_events():
193
+ res = await AworldUI.parse_output(output, openwebui_ui)
194
+ if res:
195
+ if isinstance(res, AsyncGenerator):
196
+ async for item in res:
197
+ await queue.put(item)
198
+ else:
199
+ await queue.put(res)
200
+ custom_output = await self.custom_output_after_task(outputs, chat_id, task)
201
+ if custom_output:
202
+ await queue.put(custom_output)
203
+ await queue.put(task)
204
+ finally:
205
+ await queue.put(_SENTINEL)
206
+
207
+ # Start the consumer in the background
208
+ import asyncio
209
+ consumer_task = asyncio.create_task(consume_all())
210
+
211
+ while True:
212
+ item = await queue.get()
213
+ if item is _SENTINEL:
214
+ break
215
+ yield item
216
+ await consumer_task
217
+ logging.info(f"🤖{self.agent_name()} task#{task.id} output finished🔚🔚🔚")
218
+
219
+ return async_generator
220
+
221
+ async def custom_output_before_task(self, outputs: Outputs, chat_id: str, task: Task) -> str | None:
222
+ return None
223
+
224
+ async def custom_output_after_task(self, outputs: Outputs, chat_id: str, task: Task):
225
+ pass
226
+
227
+ async def get_task_from_body(self, body: dict) -> AworldTask | None:
228
+ try:
229
+ if not body.get("user") or not body.get("user").get("aworld_task"):
230
+ return None
231
+ return AworldTask.model_validate_json(body.get("user").get("aworld_task"))
232
+ except Exception as err:
233
+ logging.error(f"Error parsing AworldTask: {err}; data: {body.get('user_message')}")
234
+ traceback.print_exc()
235
+ return None
236
+
237
+ @abstractmethod
238
+ async def load_mcp_config(self) -> dict:
239
+ pass
240
+
241
+ async def build_swarm(self, body):
242
+ return None