Upload 11 files
- aworld/core/README.md +13 -0
- aworld/core/__init__.py +8 -0
- aworld/core/common.py +92 -0
- aworld/core/exceptions.py +7 -0
- aworld/core/experiment.py +7 -0
- aworld/core/factory.py +91 -0
- aworld/core/llm_provider_base.py +217 -0
- aworld/core/memory.py +305 -0
- aworld/core/runtime_engine.py +216 -0
- aworld/core/singleton.py +63 -0
- aworld/core/task.py +92 -0
aworld/core/README.md
ADDED
@@ -0,0 +1,13 @@
# Core Components

Common functionality and system components.

- `agent/`: Base agent for sub-agents and descriptions of already registered agents.
- `envs/`: The environment and its tools, as well as the tools' related actions; a three-level, one-to-many structure.
- `context`: to be continued.
- `swarm`: Interactive collaboration in the topology of multiple agents that interact with the environment tools.
- `task`: Structure containing datasets, agents, tools, metrics, outputs, etc.
- `runner`: Completes a runnable, concrete workflow and obtains results.
aworld/core/__init__.py
ADDED
@@ -0,0 +1,8 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

try:
    from aworld.agents import agent_desc
    from examples.tools import tool_action_desc
except ImportError:
    # Optional descriptor modules; their absence is tolerated.
    pass
aworld/core/common.py
ADDED
@@ -0,0 +1,92 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

from typing import Dict, Any, Optional, Union, List

from pydantic import BaseModel

from aworld.config import ConfigDict
from aworld.core.memory import MemoryItem

Config = Union[Dict[str, Any], ConfigDict, BaseModel]


class ActionResult(BaseModel):
    """Result of executing an action with a tool."""
    is_done: bool = False
    success: bool = False
    content: Any = None
    error: Optional[str] = None
    keep: bool = False
    action_name: Optional[str] = None
    tool_name: Optional[str] = None
    # LLM tool call id
    tool_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = {}


class Observation(BaseModel):
    """Observation information obtained from the tools or transformed from the actions made by agents.

    The source can be an agent (acting as a tool) in the swarm or a tool in the virtual environment.
    """
    # Default is None, meaning the main virtual environment or swarm.
    container_id: Optional[str] = None
    # Observer who obtains the observation; an agent name or a tool name. Default is None for compatibility.
    observer: Optional[str] = None
    # The action/ability name of an agent or a tool. Default is None for compatibility.
    # NOTE: the only ability of an agent acting as a tool is handoffs.
    ability: Optional[str] = None
    # The agent that wants the observation to be created. Default is None for compatibility.
    from_agent_name: Optional[str] = None
    # The agent to which the observation should be given. Default is None for compatibility.
    to_agent_name: Optional[str] = None
    # General info for the agent.
    content: Optional[Any] = None
    # dom_tree is a str or a DomTree object.
    dom_tree: Optional[Union[str, Any]] = None
    image: Optional[str] = None  # base64
    action_result: Optional[List[ActionResult]] = []
    # For video or image lists.
    images: Optional[List[str]] = []
    # Extended key-value pairs. `done` is an internal key.
    info: Optional[Dict[str, Any]] = {}


class StatefulObservation(Observation):
    """Observations with contextual states."""
    context: List[MemoryItem]


class ParamInfo(BaseModel):
    name: Optional[str] = None
    type: str = "str"
    required: bool = False
    desc: Optional[str] = None
    default_value: Any = None


class ToolActionInfo(BaseModel):
    name: str
    input_params: Dict[str, ParamInfo] = {}
    desc: Optional[str] = None


class ActionModel(BaseModel):
    tool_name: Optional[str] = None
    tool_id: Optional[str] = None
    # Agent name.
    agent_name: Optional[str] = None
    # action_name is a tool action name chosen by agent policy.
    action_name: Optional[str] = None
    params: Optional[Dict[str, Any]] = {}
    policy_info: Optional[Any] = None


class TaskItem(BaseModel):
    data: Any
    msg: Optional[str] = None
    stop: bool = False
    success: bool = False
    action_name: Optional[str] = None
    params: Optional[Dict[str, Any]] = {}
    policy_info: Optional[Any] = None
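Usage sketch (illustrative, not part of this commit): how an observation and a follow-up action might be constructed; the tool and action names below are invented for the example.

# Illustrative only; "browser" and "click" are made-up names.
from aworld.core.common import ActionResult, ActionModel, Observation

obs = Observation(
    observer="browser",
    content="page loaded",
    action_result=[ActionResult(success=True, content="ok", tool_name="browser")],
)
action = ActionModel(
    tool_name="browser",
    action_name="click",
    params={"selector": "#submit"},
)
print(obs.observer, "->", action.action_name)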
aworld/core/exceptions.py
ADDED
@@ -0,0 +1,7 @@
class AworldException(Exception):
    """Base exception for Aworld."""

    message: str

    def __init__(self, message: str):
        super().__init__(message)  # keep str(e) consistent with the message
        self.message = message
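Usage sketch (illustrative, not part of this commit):

from aworld.core.exceptions import AworldException

try:
    raise AworldException("tool not found")
except AworldException as e:
    print(e.message)  # also visible via str(e) now that super().__init__ is called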
aworld/core/experiment.py
ADDED
@@ -0,0 +1,7 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
from pydantic import BaseModel


class Experiment(BaseModel):
    pass
aworld/core/factory.py
ADDED
@@ -0,0 +1,91 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

from typing import TypeVar, Generic, Dict, Any

from aworld.logs.util import logger

T = TypeVar("T")


class Factory(Generic[T]):
    """Base generic class used to define a (local) factory for various objects, parameterized by type T."""

    def __init__(self, type_name: str = None):
        self._type = type_name
        self._cls: Dict[str, T] = {}
        self._desc: Dict[str, str] = {}
        self._asyn: Dict[str, bool] = {}
        self._ext_info: Dict[str, Dict[Any, Any]] = {}

    def __call__(self, name: str, asyn: bool = False, **kwargs):
        """Create an object instance of the registered type by name.

        If the name is unknown, return None, or raise ValueError when `except` is passed as True.

        Returns:
            Object instance.
        """
        exception = kwargs.pop('except', False)
        # Apply the async prefix before the lookup so async registrations resolve.
        name = "async_" + name if asyn else name
        if name not in self._cls:
            if not exception:
                return None

            if self._type is None:
                raise ValueError(f"Unknown factory object type: '{self._type}'")
            raise ValueError(f"Unknown {self._type}: '{name}'")
        return self._cls[name](**kwargs)

    def __iter__(self):
        for name in self._cls:
            yield name

    def __contains__(self, name: str) -> bool:
        """Whether the name is in the factory."""
        return name in self._cls

    def get_class(self, name: str, asyn: bool = False) -> T | None:
        """Get the registered class by name."""
        name = "async_" + name if asyn else name
        return self._cls.get(name, None)

    def count(self) -> int:
        """Total number of entries in this factory."""
        return len(self._cls)

    def desc(self, name: str, asyn: bool = False) -> str:
        """Obtain the description by name."""
        name = "async_" + name if asyn else name
        return self._desc.get(name, "")

    def get_ext_info(self, name: str, asyn: bool = False) -> Dict[Any, Any]:
        """Obtain the extended info by name."""
        name = "async_" + name if asyn else name
        return self._ext_info.get(name, {})

    def register(self, name: str, desc: str, **kwargs):
        def func(cls):
            asyn = kwargs.pop("asyn", False)
            prefix = "async_" if asyn else ""
            if prefix:
                logger.debug(f"{name} has an async type, will add `async_` prefix.")

            if prefix + name in self._cls:
                equal = True
                if asyn:
                    equal = self._asyn[name] == asyn
                if equal:
                    logger.warning(f"{name} already in {self._type} factory, will override it.")

            self._asyn[name] = asyn
            self._cls[prefix + name] = cls
            self._desc[prefix + name] = desc
            self._ext_info[prefix + name] = kwargs
            return cls

        return func

    def unregister(self, name: str):
        if name in self._cls:
            logger.warning(f"unregister {name} from the {self._type} factory.")
            del self._cls[name]
            del self._desc[name]
            del self._asyn[name]
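Usage sketch (illustrative, not part of this commit): registering and instantiating through a Factory; the `Tool`/`EchoTool` names are invented for the example.

from aworld.core.factory import Factory

class Tool:  # hypothetical base type
    pass

tool_factory: Factory[Tool] = Factory(type_name="tool")

@tool_factory.register(name="echo", desc="Echo tool (example)")
class EchoTool(Tool):
    def __init__(self, prefix: str = ">"):
        self.prefix = prefix

tool = tool_factory("echo", prefix=">>")  # instantiates EchoTool via __call__
print(tool_factory.desc("echo"), tool.prefix)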
aworld/core/llm_provider_base.py
ADDED
@@ -0,0 +1,217 @@
import abc
from collections import Counter
from typing import (
    Any,
    List,
    Dict,
    Generator,
    AsyncGenerator,
)

from aworld.models.model_response import ModelResponse


class LLMProviderBase(abc.ABC):
    """Base class for large language model providers; defines the unified interface."""

    def __init__(self,
                 api_key: str = None,
                 base_url: str = None,
                 model_name: str = None,
                 sync_enabled: bool = None,
                 async_enabled: bool = None,
                 **kwargs):
        """Initialize the provider.

        Args:
            api_key: API key.
            base_url: Service URL.
            model_name: Model name.
            sync_enabled: Whether a sync provider is needed.
            async_enabled: Whether an async provider is needed.
            **kwargs: Other parameters.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.kwargs = kwargs
        # Determine whether to initialize sync and async providers.
        self.need_sync = sync_enabled if sync_enabled is not None else async_enabled is not True
        self.need_async = async_enabled if async_enabled is not None else sync_enabled is not True

        # Initialize providers based on the flags.
        self.provider = self._init_provider() if self.need_sync else None
        self.async_provider = self._init_async_provider() if self.need_async else None
        self.stream_tool_buffer = []

    @abc.abstractmethod
    def _init_provider(self):
        """Initialize the specific provider instance; to be implemented by subclasses.

        Returns:
            Provider instance.
        """

    def _init_async_provider(self):
        """Initialize the async provider instance; optional for subclasses that don't need async support.

        Only called if the async provider is needed.

        Returns:
            Async provider instance.
        """
        return None

    @classmethod
    def supported_models(cls) -> list[str]:
        return []

    def preprocess_messages(self, messages: List[Dict[str, str]]) -> Any:
        """Preprocess messages, converting OpenAI-format messages to the provider-specific format.

        Args:
            messages: OpenAI-format message list [{"role": "system", "content": "..."}, ...].

        Returns:
            Converted messages; the format is determined by the specific provider.
        """
        return messages

    @abc.abstractmethod
    def postprocess_response(self, response: Any) -> ModelResponse:
        """Post-process the response, converting the provider response to a unified ModelResponse.

        Args:
            response: Original response from the provider.

        Returns:
            ModelResponse: Unified-format response object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """

    def postprocess_stream_response(self, chunk: Any) -> ModelResponse:
        """Post-process a streaming response chunk, converting it to a unified ModelResponse.

        Args:
            chunk: Original response chunk from the provider.

        Returns:
            ModelResponse: Unified-format response object for the chunk.

        Raises:
            LLMResponseError: When an LLM response error occurs.
        """

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call the model to generate a response.

        Args:
            messages: Message list, format [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            RuntimeError: When the async provider is not initialized.
        """
        if not self.async_provider:
            raise RuntimeError(
                "Async provider not initialized. Make sure 'async_enabled' parameter is set to True in initialization.")

    @abc.abstractmethod
    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call the model to generate a response.

        Args:
            messages: Message list, format [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            RuntimeError: When the sync provider is not initialized.
        """

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call the model to generate a streaming response.

        Args:
            messages: Message list, format [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            RuntimeError: When the sync provider is not initialized.
        """

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call the model to generate a streaming response.

        Args:
            messages: Message list, format [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.

        Raises:
            LLMResponseError: When an LLM response error occurs.
            RuntimeError: When the async provider is not initialized.
        """

    def _accumulate_chunk_usage(self, usage: Dict[str, int], chunk_usage: Dict[str, int]):
        """Accumulate usage statistics from a chunk into the main usage dictionary.

        Args:
            usage: Dictionary to accumulate usage into (will be modified).
            chunk_usage: Usage statistics from the current chunk.
        """
        if not chunk_usage:
            return

        usage.update(dict(Counter(usage) + Counter(chunk_usage)))

    def speech_to_text(self, audio_file, language, prompt, **kwargs) -> ModelResponse:
        pass

    async def aspeech_to_text(self, audio_file, language, prompt, **kwargs) -> ModelResponse:
        pass
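A minimal concrete-provider sketch (illustrative, not part of this commit). A real subclass would wrap an actual SDK client and return a fully populated ModelResponse; the echo logic here is a placeholder.

from aworld.core.llm_provider_base import LLMProviderBase

class EchoProvider(LLMProviderBase):
    def _init_provider(self):
        # A real provider builds its SDK client here.
        return object()

    def postprocess_response(self, response):
        # A real provider converts `response` into a ModelResponse.
        return response

    def completion(self, messages, temperature=0.0, max_tokens=None, stop=None, **kwargs):
        # Placeholder "model": echo the last message content back.
        return self.postprocess_response(messages[-1]["content"])

provider = EchoProvider(model_name="echo", sync_enabled=True, async_enabled=False)
print(provider.completion([{"role": "user", "content": "hello"}]))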
aworld/core/memory.py
ADDED
@@ -0,0 +1,305 @@
import datetime
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Any, Literal

from pydantic import BaseModel, Field, ConfigDict

from aworld.models.llm import LLMModel


class MemoryItem(BaseModel):
    id: str = Field(description="id")
    content: Any = Field(description="content")
    created_at: Optional[str] = Field(None, description="created at")
    updated_at: Optional[str] = Field(None, description="updated at")
    metadata: dict = Field(
        description="metadata, used to store additional information such as user_id, agent_id, run_id, task_id, etc.")
    tags: list[str] = Field(description="tags")
    histories: list["MemoryItem"] = Field(default_factory=list)
    deleted: bool = Field(default=False)
    memory_type: Literal["init", "message", "summary", "agent_experience", "user_profile"] = Field(default="message")
    version: int = Field(description="version")

    def __init__(self, **data):
        # Set default values for optional fields.
        if "id" not in data:
            data["id"] = str(uuid.uuid4())
        if "created_at" not in data:
            data["created_at"] = datetime.datetime.now().isoformat()
        if "updated_at" not in data:
            data["updated_at"] = data["created_at"]
        if "metadata" not in data:
            data["metadata"] = {}
        if "tags" not in data:
            data["tags"] = []
        if "version" not in data:
            data["version"] = 1

        super().__init__(**data)

    @classmethod
    def from_dict(cls, data: dict) -> "MemoryItem":
        """Create a MemoryItem instance from a dictionary.

        Args:
            data (dict): A dictionary containing the memory item data.

        Returns:
            MemoryItem: An instance of MemoryItem.
        """
        return cls(**data)


class MemoryStore(ABC):
    """Memory store interface for message history."""

    @abstractmethod
    def add(self, memory_item: MemoryItem):
        pass

    @abstractmethod
    def get(self, memory_id) -> Optional[MemoryItem]:
        pass

    @abstractmethod
    def get_first(self, filters: dict = None) -> Optional[MemoryItem]:
        pass

    @abstractmethod
    def total_rounds(self, filters: dict = None) -> int:
        pass

    @abstractmethod
    def get_all(self, filters: dict = None) -> list[MemoryItem]:
        pass

    @abstractmethod
    def get_last_n(self, last_rounds, filters: dict = None) -> list[MemoryItem]:
        pass

    @abstractmethod
    def update(self, memory_item: MemoryItem):
        pass

    @abstractmethod
    def delete(self, memory_id):
        pass

    @abstractmethod
    def history(self, memory_id) -> list[MemoryItem] | None:
        pass


class MemoryBase(ABC):

    @abstractmethod
    def get(self, memory_id) -> Optional[MemoryItem]:
        """Get an item in memory by ID.

        Args:
            memory_id (str): ID of the memory to retrieve.

        Returns:
            MemoryItem: Retrieved memory.
        """

    @abstractmethod
    def get_all(self, filters: dict = None) -> Optional[list[MemoryItem]]:
        """List all items in the memory store.

        Args:
            filters (dict, optional): Filters to apply to the search. Defaults to None.
                - user_id (str, optional): ID of the user to search for. Defaults to None.
                - agent_id (str, optional): ID of the agent to search for. Defaults to None.
                - session_id (str, optional): ID of the session to search for. Defaults to None.

        Returns:
            list: List of all memories.
        """

    @abstractmethod
    def get_last_n(self, last_rounds, filters: dict = None) -> Optional[list[MemoryItem]]:
        """Get the memories of the last `last_rounds` rounds.

        Args:
            last_rounds (int): Number of last rounds to retrieve.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
                - user_id (str, optional): ID of the user to search for. Defaults to None.
                - agent_id (str, optional): ID of the agent to search for. Defaults to None.
                - session_id (str, optional): ID of the session to search for. Defaults to None.

        Returns:
            list: List of the latest memories.
        """

    @abstractmethod
    def search(self, query, limit=100, filters=None) -> Optional[list[MemoryItem]]:
        """Search for memories.

        Hybrid search: retrieve memories from the vector store and the memory store.

        Args:
            query (str): Query to search for.
            limit (int, optional): Limit the number of results. Defaults to 100.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
                - user_id (str, optional): ID of the user to search for. Defaults to None.
                - agent_id (str, optional): ID of the agent to search for. Defaults to None.
                - session_id (str, optional): ID of the session to search for. Defaults to None.

        Returns:
            list: List of search results.
        """

    @abstractmethod
    def add(self, memory_item: MemoryItem, filters: dict = None):
        """Add a memory to the memory store.

        Step 1: Add the memory to the memory store.
        Step 2: Add the memory to the vector store.

        Args:
            memory_item (MemoryItem): Memory item, carrying metadata, tags, memory_type and version.
            filters (dict, optional): Filters to associate with the memory. Defaults to None.
                - user_id (str, optional): ID of the user. Defaults to None.
                - agent_id (str, optional): ID of the agent. Defaults to None.
                - session_id (str, optional): ID of the session. Defaults to None.
        """

    @abstractmethod
    def update(self, memory_item: MemoryItem):
        """Update a memory by ID.

        Args:
            memory_item (MemoryItem): Memory item.
        """

    @abstractmethod
    async def async_gen_cur_round_summary(self, to_be_summary: MemoryItem, filters: dict, last_rounds: int) -> str:
        """A tool for reducing the context length of the current round.

        Step 1: Retrieve historical conversation content based on filters and last_rounds.
        Step 2: Extract the current round content and the most relevant historical content.
        Step 3: Generate the corresponding summary for the current round.

        Args:
            to_be_summary (MemoryItem): Message to summarize.
            filters (dict): Filters used to get the memory list.
            last_rounds (int): Last rounds of the memory list.

        Returns:
            str: Summary memory.
        """

    @abstractmethod
    async def async_gen_multi_rounds_summary(self, to_be_summary: list[MemoryItem]) -> str:
        """A tool for summarizing a list of memory items.

        Args:
            to_be_summary (list[MemoryItem]): The list of memory items.
        """

    @abstractmethod
    async def async_gen_summary(self, filters: dict, last_rounds: int) -> str:
        """A tool for summarizing the conversation history.

        Step 1: Retrieve historical conversation content based on filters and last_rounds.
        Step 2: Generate the corresponding summary for the conversation history.

        Args:
            filters (dict): Filters used to get the memory list.
            last_rounds (int): Last rounds of the memory list.
        """

    @abstractmethod
    def delete(self, memory_id):
        """Delete a memory by ID.

        Args:
            memory_id (str): ID of the memory to delete.
        """


SUMMARY_PROMPT = """
You are a helpful assistant that summarizes the conversation history.
- Summarize the following text in one clear and concise paragraph, capturing the key ideas without missing critical points.
- Ensure the summary is easy to understand and avoids excessive detail.

Here is the content:
{context}
"""


class MemoryConfig(BaseModel):
    """Configuration for procedural memory."""

    model_config = ConfigDict(
        from_attributes=True, validate_default=True, revalidate_instances='always', validate_assignment=True,
        arbitrary_types_allowed=True
    )

    # Memory config
    provider: Literal['inmemory', 'mem0'] = 'inmemory'
    enable_summary: bool = Field(default=False, description="use an LLM to create summary memory")
    summary_rounds: int = Field(default=5, description="when the number of messages exceeds summary_rounds, a summary will be created")
    summary_single_context_length: int = Field(default=4000, description="when the content length exceeds summary_single_context_length, a summary will be created")
    summary_prompt: str = Field(default=SUMMARY_PROMPT, description="summary prompt")

    # Embedder settings
    embedder_provider: Literal['openai', 'gemini', 'ollama', 'huggingface'] = 'huggingface'
    embedder_model: str = Field(min_length=2, default='all-MiniLM-L6-v2')
    embedder_dims: int = Field(default=384, gt=10, lt=10000)

    # LLM settings - the LLM instance can be passed separately
    llm_provider: Literal['openai', 'langchain'] = 'langchain'
    llm_instance: Optional[LLMModel] = None

    # Vector store settings
    vector_store_provider: Literal['faiss'] = 'faiss'
    vector_store_base_path: str = Field(default='/tmp/mem0_aworld')

    @property
    def vector_store_path(self) -> str:
        """Returns the full vector store path for the current configuration, e.g. /tmp/mem0_aworld_384_faiss."""
        return f'{self.vector_store_base_path}_{self.embedder_dims}_{self.vector_store_provider}'

    @property
    def embedder_config_dict(self) -> dict[str, Any]:
        """Returns the embedder configuration dictionary."""
        return {
            'provider': self.embedder_provider,
            'config': {'model': self.embedder_model, 'embedding_dims': self.embedder_dims},
        }

    @property
    def llm_config_dict(self) -> dict[str, Any]:
        """Returns the LLM configuration dictionary."""
        return {'provider': self.llm_provider, 'config': {'model': self.llm_instance}}

    @property
    def vector_store_config_dict(self) -> dict[str, Any]:
        """Returns the vector store configuration dictionary."""
        return {
            'provider': self.vector_store_provider,
            'config': {
                'embedding_model_dims': self.embedder_dims,
                'path': self.vector_store_path,
            },
        }

    @property
    def full_config_dict(self) -> dict[str, dict[str, Any]]:
        """Returns the complete configuration dictionary for Mem0."""
        return {
            'embedder': self.embedder_config_dict,
            'llm': self.llm_config_dict,
            'vector_store': self.vector_store_config_dict,
        }
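Usage sketch (illustrative, not part of this commit): MemoryItem fills its defaults in __init__, and MemoryConfig derives the Mem0 dictionaries from a few fields.

from aworld.core.memory import MemoryItem, MemoryConfig

item = MemoryItem(content="user asked about pricing", tags=["message"])
print(item.id, item.version)          # id and version are auto-filled

conf = MemoryConfig()                  # defaults: inmemory provider + huggingface embedder
print(conf.vector_store_path)          # /tmp/mem0_aworld_384_faiss
print(conf.embedder_config_dict)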
aworld/core/runtime_engine.py
ADDED
@@ -0,0 +1,216 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import inspect
import os
import asyncio
from concurrent.futures import Future
from concurrent.futures.process import ProcessPoolExecutor
from types import MethodType
from typing import List, Callable, Any, Dict

from aworld.config import RunConfig, ConfigDict
from aworld.logs.util import logger
from aworld.utils.common import sync_exec

LOCAL = "local"
SPARK = "spark"
RAY = "ray"
K8S = "k8s"


class RuntimeEngine(object):
    """Lightweight wrapper around a computing engine runtime."""

    __metaclass__ = abc.ABCMeta

    def __init__(self, conf: RunConfig):
        """Initialize the engine runtime instance."""
        self.conf = ConfigDict(conf.model_dump())
        self.runtime = None
        # Register clients running on top of distributed computing engines.
        register(conf.name, self)

    def build_engine(self) -> 'RuntimeEngine':
        """Create the computing engine runtime.

        Building more than once on the same runtime instance returns the same engine instance, like getOrCreate.
        """
        if self.runtime is not None:
            return self
        self._build_engine()
        return self

    @abc.abstractmethod
    def _build_engine(self) -> None:
        raise NotImplementedError("Base _build_engine not implemented!")

    @abc.abstractmethod
    def broadcast(self, data: Any):
        """Broadcast the data to all workers."""

    @abc.abstractmethod
    async def execute(self, funcs: List[Callable[..., Any]], *args, **kwargs) -> Dict[str, Any]:
        """Submit and execute stateless tasks on the specific engine cluster."""
        raise NotImplementedError("Base task execute not implemented!")

    def pre_execute(self):
        """Define the pre-execution logic."""
        pass

    def post_execute(self):
        """Define the post-execution logic."""
        pass


class LocalRuntime(RuntimeEngine):
    """Local runtime key is 'local'; executes tasks on the local machine.

    The local runtime is used to verify or test locally.
    """

    def _build_engine(self):
        self.runtime = self

    def func_wrapper(self, func, *args, **kwargs):
        """Adapt the function to the computing form."""

        if inspect.iscoroutinefunction(func):
            res = sync_exec(func, *args, **kwargs)
        else:
            res = func(*args, **kwargs)
        return res

    async def execute(self, funcs: List[Callable[..., Any]], *args, **kwargs) -> Dict[str, Any]:
        # Optimization for the single-task case.
        if len(funcs) == 1:
            func = funcs[0]
            if inspect.iscoroutinefunction(func):
                res = await func(*args, **kwargs)
            else:
                res = func(*args, **kwargs)
            return {res.id: res}

        num_executor = self.conf.get('worker_num', os.cpu_count() - 1)
        num_process = min(len(funcs), num_executor)
        if num_process <= 0:
            num_process = 1

        futures = []
        with ProcessPoolExecutor(num_process) as pool:
            for func in funcs:
                futures.append(pool.submit(self.func_wrapper, func, *args, **kwargs))

        results = {}
        for future in futures:
            future: Future = future
            res = future.result()
            results[res.id] = res
        return results


class K8sRuntime(LocalRuntime):
    """K8s runtime key is 'k8s'; executes tasks in a Kubernetes cluster."""


class KubernetesRuntime(LocalRuntime):
    """Kubernetes runtime key is 'kubernetes'; executes tasks in a Kubernetes cluster."""


class SparkRuntime(RuntimeEngine):
    """Spark runtime key is 'spark'; executes tasks in a Spark cluster.

    Note: the Spark runtime must run on the driver end.
    """

    def __init__(self, engine_options):
        super(SparkRuntime, self).__init__(engine_options)

    def _build_engine(self):
        from pyspark.sql import SparkSession

        conf = self.conf
        is_local = conf.get('is_local', True)
        logger.info('build runtime is_local:{}'.format(is_local))
        spark_builder = SparkSession.builder
        if is_local:
            if 'PYSPARK_PYTHON' not in os.environ:
                import sys
                os.environ['PYSPARK_PYTHON'] = sys.executable

            spark_builder = spark_builder.master('local[1]').config('spark.executor.instances', '1')

        self.runtime = spark_builder.appName(conf.get('job_name', 'aworld_spark_job')).getOrCreate()

    def args_process(self, *args):
        re_args = []
        for arg in args:
            if arg:
                options = self.runtime.sparkContext.broadcast(arg)
                arg = options.value
            re_args.append(arg)
        return re_args

    async def execute(self, funcs: List[Callable[..., Any]], *args, **kwargs) -> Dict[str, Any]:
        re_args = self.args_process(*args)
        res_rdd = self.runtime.sparkContext.parallelize(funcs, len(funcs)).map(
            lambda func: func(*re_args, **kwargs))

        res_list = res_rdd.collect()
        results = {res.id: res for res in res_list}
        return results


class RayRuntime(RuntimeEngine):
    """Ray runtime key is 'ray'; executes tasks in a Ray cluster.

    The Ray runtime in TaskRuntimeBackend only executes functions (stateless); it can be used for
    advanced features such as custom resource allocation and communication.
    """

    def __init__(self, engine_options):
        super(RayRuntime, self).__init__(engine_options)

    def _build_engine(self):
        import ray

        if not ray.is_initialized():
            ray.init()

        self.runtime = ray
        self.num_executors = self.conf.get('num_executors', 1)
        logger.info("ray init finished, executor number {}".format(str(self.num_executors)))

    async def execute(self, funcs: List[Callable[..., Any]], *args, **kwargs) -> Dict[str, Any]:
        @self.runtime.remote
        def fn_wrapper(fn, *args):
            if asyncio.iscoroutinefunction(fn):
                return sync_exec(fn, *args, **kwargs)
            else:
                real_args = [arg for arg in args if not isinstance(arg, MethodType)]
                return fn(*real_args, **kwargs)

        params = []
        for arg in args:
            params.append([arg] * len(funcs))

        ray_map = lambda func, fn: [func.remote(x, *y) for x, *y in zip(fn, *params)]
        res_list = self.runtime.get(ray_map(fn_wrapper, funcs))
        return {res.id: res for res in res_list}


RUNTIME: Dict[str, RuntimeEngine] = {}


def register(key, runtime_backend):
    if RUNTIME.get(key, None) is not None:
        logger.debug("{} runtime backend already exists, will reuse the client.".format(key))
        return

    RUNTIME[key] = runtime_backend
    logger.info("register {}:{} success".format(key, runtime_backend))
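Usage sketch (illustrative, not part of this commit): running a stateless function on the local runtime. Result objects must expose an `id` attribute, since `execute` keys its result dict by it; the `RunConfig(name="local")` construction is an assumption about that config model.

import asyncio
from dataclasses import dataclass

from aworld.config import RunConfig
from aworld.core.runtime_engine import LocalRuntime

@dataclass
class Result:  # hypothetical result type carrying the required `id`
    id: str
    value: int

def work(x: int) -> Result:
    return Result(id="job-1", value=x * 2)

async def main():
    engine = LocalRuntime(RunConfig(name="local")).build_engine()
    results = await engine.execute([work], 21)  # single-task fast path
    print(results["job-1"].value)               # 42

asyncio.run(main())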
aworld/core/singleton.py
ADDED
@@ -0,0 +1,63 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import threading

from aworld.logs.util import logger


class SingletonMeta(type):
    _instances = {}
    _lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        """Create or get the class instance."""
        with cls._lock:
            if cls not in cls._instances:
                instance = super(SingletonMeta, cls).__call__(*args, **kwargs)
                cls._instances[cls] = instance
            return cls._instances[cls]


class InheritanceSingleton(object, metaclass=SingletonMeta):
    _local_instances = {}

    @staticmethod
    def __get_base_class(clazz):
        if clazz == object:
            return None

        bases = clazz.__bases__
        for base in bases:
            if base == InheritanceSingleton:
                return clazz
            else:
                base_class = InheritanceSingleton.__get_base_class(base)
                if base_class:
                    return base_class
        return None

    def __new__(cls, *args, **kwargs):
        base = InheritanceSingleton.__get_base_class(cls)
        if base is None:
            raise ValueError(f"{cls} singleton base not found")

        return super(InheritanceSingleton, cls).__new__(cls)

    @classmethod
    def instance(cls, *args, **kwargs):
        """Each thread has its own singleton instance."""

        if cls.__name__ not in cls._local_instances:
            cls._local_instances[cls.__name__] = threading.local()

        local_instance = cls._local_instances[cls.__name__]
        if not hasattr(local_instance, 'instance'):
            logger.info(f"{threading.current_thread().name} thread create {cls} instance.")
            local_instance.instance = cls(*args, **kwargs)

        return local_instance.instance

    @classmethod
    def clear_singleton(cls):
        base = InheritanceSingleton.__get_base_class(cls)
        cls._instances.pop(base, None)
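Usage sketch (illustrative, not part of this commit): direct construction is de-duplicated by the metaclass, while `instance()` goes through the per-thread accessor; `Registry` is an invented subclass.

from aworld.core.singleton import InheritanceSingleton

class Registry(InheritanceSingleton):  # hypothetical subclass
    def __init__(self):
        self.items = []

a = Registry()
b = Registry()
assert a is b                 # SingletonMeta caches one instance per class

t = Registry.instance()       # thread-local accessor
assert t is a                 # same object within this thread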
aworld/core/task.py
ADDED
@@ -0,0 +1,92 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import uuid
from dataclasses import dataclass, field
from typing import Any, Union, List, Dict, Callable, Optional

from pydantic import BaseModel

from aworld.agents.llm_agent import Agent
from aworld.core.agent.swarm import Swarm
from aworld.core.common import Config
from aworld.core.context.base import Context
from aworld.core.tool.base import Tool, AsyncTool
from aworld.output.outputs import Outputs, DefaultOutputs


@dataclass
class Task:
    # default_factory ensures each Task gets a fresh id/name; a plain class-level
    # default would be evaluated only once and shared across instances.
    id: str = field(default_factory=lambda: uuid.uuid1().hex)
    name: str = field(default_factory=lambda: uuid.uuid1().hex)
    user_id: str = None
    session_id: str = None
    input: Any = None
    # task config
    conf: Config = None
    # global tool instances
    tools: List[Union[Tool, AsyncTool]] = field(default_factory=list)
    # global tool names
    tool_names: List[str] = field(default_factory=list)
    # custom tool conf
    tools_conf: Config = field(default_factory=dict)
    # custom MCP servers conf
    mcp_servers_conf: Config = field(default_factory=dict)
    swarm: Optional[Swarm] = None
    agent: Optional[Agent] = None
    event_driven: bool = True
    # for loop detection
    endless_threshold: int = 3
    # task outputs
    outputs: Outputs = field(default_factory=DefaultOutputs)
    # task-specific runner class, for example: package.XXRunner
    runner_cls: Optional[str] = None
    # such as: {"start": ["init_tool", "init_context", ...]}
    hooks: Dict[str, List[str]] = field(default_factory=dict)
    # task-specific context
    context: 'Context' = None


class TaskResponse(BaseModel):
    id: str
    answer: str | None
    usage: Dict[str, Any] | None = None
    time_cost: float | None = None
    success: bool = False
    msg: str | None = None


class Runner(object):
    __metaclass__ = abc.ABCMeta

    _use_demon: bool = False
    daemon_target: Callable[..., Any] = None
    context: Context = None

    async def pre_run(self):
        pass

    async def post_run(self):
        pass

    @abc.abstractmethod
    async def do_run(self, context: Context = None):
        """Raise an exception if not successful."""

    async def _daemon_run(self):
        if self._use_demon and self.daemon_target and callable(self.daemon_target):
            import threading
            t = threading.Thread(target=self.daemon_target, name="daemon", daemon=True)
            t.start()

    async def run(self) -> Any:
        try:
            await self.pre_run()
            await self._daemon_run()
            ret = await self.do_run(self.context)
            return 0 if ret is None else ret
        except BaseException as ex:
            self._exception = ex
            # do record or report
            raise ex
        finally:
            await self.post_run()