Duibonduil commited on
Commit
6f8def7
·
verified ·
1 Parent(s): 7214bec

Upload 5 files

Browse files
aworld/memory/README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## multi-agents memory
2
+ ![](../../readme_assets/framework_memory_example.png)
3
+
4
+ ### Short-Term Memory
5
+
6
+ Short-term memory (InMemory) is suitable for lightweight, temporary multi-agent memory scenarios. Data is only stored in memory, making it ideal for testing and small-scale experiments.
7
+
8
+ **Usage Example:**
9
+
10
+ ```python
11
+ from aworld.core.memory import MemoryConfig, MemoryItem
12
+ from aworld.memory.main import MemoryFactory
13
+
14
+ # Create InMemory config
15
+ memory_config = MemoryConfig(provider="inmemory", enable_summary=False)
16
+ # Initialize Memory
17
+ memory = MemoryFactory.from_config(memory_config)
18
+
19
+ # Add a memory item
20
+ memory.add(MemoryItem(content="Hello, world!", metadata={"user_id": "u1"}, tags=["greeting"]))
21
+
22
+ # Get all memory items
23
+ all_memories = memory.get_all()
24
+ for item in all_memories:
25
+ print(item.content)
26
+ ```
27
+
28
+ ### Long-Term Memory
29
+
30
+ Long-term memory (Mem0) is suitable for persistent, vectorized retrieval and summarization in multi-agent scenarios. It supports LLM-based summarization and vector storage.
31
+
32
+ **Usage Example:**
33
+
34
+ ```python
35
+ from aworld.core.memory import MemoryConfig, MemoryItem
36
+ from aworld.memory.main import MemoryFactory
37
+
38
+ # Create Mem0 config (requires mem0 and related dependencies)
39
+ memory_config = MemoryConfig(
40
+ provider="mem0",
41
+ enable_summary=True, # Enable summarization
42
+ summary_rounds=5, # Generate a summary every 5 rounds
43
+ embedder_provider="huggingface", # Embedding model provider
44
+ embedder_model="all-MiniLM-L6-v2", # Embedding model name
45
+ embedder_dims=384
46
+ )
47
+ # Initialize Memory
48
+ memory = MemoryFactory.from_config(memory_config)
49
+
50
+ # Add a memory item
51
+ memory.add(MemoryItem(content="The agent visited Hangzhou.", metadata={"user_id": "u1"}, tags=["travel"]))
52
+
53
+ # Get all memory items
54
+ all_memories = memory.get_all()
55
+ for item in all_memories:
56
+ print(item.content)
57
+ ```
58
+
59
+ > Note: To use mem0, you must install `mem0` and `sentence-transformers` in advance, and configure the required LLM environment variables.
60
+
61
+ ### CheckPoint
62
+ TODO
aworld/memory/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
aworld/memory/main.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ import abc
4
+ import asyncio
5
+ import json
6
+ import os
7
+ from typing import Optional
8
+
9
+ from aworld.config import ConfigDict
10
+ from aworld.core.memory import MemoryBase, MemoryItem, MemoryStore, MemoryConfig
11
+ from aworld.logs.util import logger
12
+ from aworld.models.llm import get_llm_model, acall_llm_model
13
+
14
+
15
class InMemoryMemoryStore(MemoryStore):
    """Simple list-backed MemoryStore for tests and lightweight scenarios.

    Items are kept in insertion order. Deletion is a soft delete: the item is
    flagged ``deleted`` and skipped by every read path that applies filtering.
    """

    def __init__(self):
        # Insertion-ordered backing list of MemoryItem objects.
        self.memory_items: list["MemoryItem"] = []

    def add(self, memory_item: "MemoryItem"):
        """Append a memory item to the store."""
        self.memory_items.append(memory_item)

    def get(self, memory_id) -> Optional["MemoryItem"]:
        """Return the item with the given id (even if soft-deleted), or None."""
        return next((item for item in self.memory_items if item.id == memory_id), None)

    def get_first(self, filters: dict = None) -> Optional["MemoryItem"]:
        """Return the oldest non-deleted item matching filters, or None."""
        filtered_items = self.get_all(filters)
        return filtered_items[0] if filtered_items else None

    def total_rounds(self, filters: dict = None) -> int:
        """Return the number of non-deleted items matching filters."""
        return len(self.get_all(filters))

    def get_all(self, filters: dict = None) -> list["MemoryItem"]:
        """Return all non-deleted items matching filters, in insertion order."""
        return [item for item in self.memory_items if self._filter_memory_item(item, filters)]

    def _filter_memory_item(self, memory_item: "MemoryItem", filters: dict = None) -> bool:
        """Return True if the item is not deleted and matches every provided filter key."""
        if memory_item.deleted:
            return False
        if filters is None:
            return True
        # Metadata-backed filter keys; each one, when present in filters,
        # must match the item's metadata exactly (a missing metadata value
        # never matches a non-None expectation).
        for key in ("user_id", "agent_id", "task_id", "session_id"):
            expected = filters.get(key)
            if expected is not None and memory_item.metadata.get(key) != expected:
                return False
        expected_type = filters.get("memory_type")
        if expected_type is not None and memory_item.memory_type != expected_type:
            return False
        return True

    def get_last_n(self, last_rounds, filters: dict = None) -> list["MemoryItem"]:
        """Return the last ``last_rounds`` non-deleted items matching filters.

        Fixes: the previous implementation ignored ``filters`` and soft-deleted
        items, and returned the entire list when ``last_rounds == 0`` (because
        ``items[-0:]`` is ``items[0:]``).
        """
        if last_rounds <= 0:
            return []
        return self.get_all(filters)[-last_rounds:]

    def update(self, memory_item: "MemoryItem"):
        """Replace the stored item that has the same id, if any."""
        for index, item in enumerate(self.memory_items):
            if item.id == memory_item.id:
                self.memory_items[index] = memory_item
                break

    def delete(self, memory_id):
        """Soft-delete the item with the given id (no-op if absent)."""
        exists = self.get(memory_id)
        if exists:
            exists.deleted = True

    def history(self, memory_id) -> list["MemoryItem"] | None:
        """Return the item's history list, or None if the id is unknown."""
        exists = self.get(memory_id)
        if exists:
            return exists.histories
        return None
92
+
93
+
94
class MemoryFactory:
    """Factory that builds a concrete Memory backend from a MemoryConfig."""

    @classmethod
    def from_config(cls, config: "MemoryConfig") -> "MemoryBase":
        """
        Create a Memory instance from a configuration object.

        Args:
            config: Memory configuration; ``config.provider`` selects the
                backend ("inmemory" or "mem0").

        Returns:
            MemoryBase: The configured memory backend.

        Raises:
            ValueError: If ``config.provider`` names an unknown backend.
        """
        if config.provider == "inmemory":
            return InMemoryStorageMemory(
                memory_store=InMemoryMemoryStore(),
                config=config,
                enable_summary=config.enable_summary,
                summary_rounds=config.summary_rounds
            )
        elif config.provider == "mem0":
            # Imported lazily so the optional mem0 dependency is only needed
            # when this provider is actually selected.
            from aworld.memory.mem0.mem0_memory import Mem0Memory
            return Mem0Memory(
                memory_store=InMemoryMemoryStore(),
                config=config
            )
        else:
            # Fix: previous message called config.get('memory_store'), which
            # does not exist on MemoryConfig and reported the wrong field.
            raise ValueError(f"Invalid memory provider: {config.provider}")
122
+
123
+
124
class Memory(MemoryBase):
    """Base memory implementation shared by concrete backends.

    Holds the backing ``memory_store`` and ``config``, lazily creates an LLM
    client for summarization, and provides helpers that condense conversation
    history into summary text.
    """
    # NOTE(review): __metaclass__ has no effect in Python 3; kept as-is.
    __metaclass__ = abc.ABCMeta

    def __init__(self, memory_store: "MemoryStore", config: "MemoryConfig", **kwargs):
        self.memory_store = memory_store
        self.config = config
        # Created lazily by `default_llm_instance` on first access.
        self._llm_instance = None

    @property
    def default_llm_instance(self):
        """LLM client used for summaries, built on first access.

        Each setting prefers a MEM_-prefixed environment variable and falls
        back to the generic variable.
        """
        def get_env(key: str, default_key: str, default_val: object = None):
            return os.getenv(key) if os.getenv(key) else os.getenv(default_key, default_val)

        if not self._llm_instance:
            self._llm_instance = get_llm_model(conf=ConfigDict({
                "llm_model_name": get_env("MEM_LLM_MODEL_NAME", "LLM_MODEL_NAME"),
                # Fix: the fallback key was "LLM_MODEL_NAME", which silently
                # used the model name as the API key when MEM_LLM_API_KEY
                # was unset.
                "llm_api_key": get_env("MEM_LLM_API_KEY", "LLM_API_KEY"),
                "llm_base_url": get_env("MEM_LLM_BASE_URL", 'LLM_BASE_URL'),
                # Fix: the fallback key duplicated "MEM_LLM_TEMPERATURE";
                # fall back to the generic variable, defaulting to 1.0.
                "temperature": get_env("MEM_LLM_TEMPERATURE", "LLM_TEMPERATURE", 1.0),
                "streaming": 'False'
            }))
        return self._llm_instance

    def _build_history_context(self, messages) -> str:
        """Build the history context string from a list of messages.

        Args:
            messages: Dicts with 'role', 'content', and optional 'tool_calls'.
        Returns:
            Concatenated context string.
        """
        history_context = ""
        for item in messages:
            history_context += (f"\n\n{item['role']}: {item['content']}, "
                                f"{'tool_calls:' + json.dumps(item['tool_calls']) if 'tool_calls' in item and item['tool_calls'] else ''}")
        return history_context

    async def _call_llm_summary(self, summary_messages: list) -> str:
        """Call the LLM to generate a summary and log the process.

        Args:
            summary_messages: Messages to send to the LLM.
        Returns:
            Summary content string.
        """
        logger.info(f"🤔 [Summary] Creating summary memory, history messages: {summary_messages}")
        llm_response = await acall_llm_model(
            self.default_llm_instance,
            messages=summary_messages,
            stream=False
        )
        # Fix: only flag the logged preview as truncated when it actually is.
        preview = llm_response.content[:400] + ("...truncated" if len(llm_response.content) > 400 else "")
        logger.info(f'🤔 [Summary] summary_content: result is {preview} ')
        return llm_response.content

    def _get_parsed_history_messages(self, history_items: list["MemoryItem"]) -> list[dict]:
        """Convert MemoryItem objects into summary-ready message dicts.

        Args:
            history_items: Items whose metadata carries 'role' (and optionally
                'tool_calls').
        Returns:
            List of parsed message dicts.
        """
        parsed_messages = [
            {
                'role': message.metadata['role'],
                'content': message.content,
                # Falsy tool_calls (missing, None, empty) normalize to None.
                'tool_calls': message.metadata.get('tool_calls') or None
            }
            for message in history_items]
        return parsed_messages

    async def async_gen_multi_rounds_summary(self, to_be_summary: list["MemoryItem"]) -> str:
        """Summarize the given memory items into one string ("" if empty)."""
        logger.info(
            f"🤔 [Summary] Creating summary memory, history messages")
        if len(to_be_summary) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(to_be_summary)
        history_context = self._build_history_context(parsed_messages)

        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]

        return await self._call_llm_summary(summary_messages)

    async def async_gen_summary(self, filters: dict, last_rounds: int) -> str:
        """Summarize the last ``last_rounds`` stored messages matching filters."""
        logger.info(f"🤔 [Summary] Creating summary memory, history messages [filters -> {filters}, "
                    f"last_rounds -> {last_rounds}]")
        history_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        if len(history_items) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(history_items)
        history_context = self._build_history_context(parsed_messages)

        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]

        return await self._call_llm_summary(summary_messages)

    async def async_gen_cur_round_summary(self, to_be_summary: "MemoryItem", filters: dict, last_rounds: int) -> str:
        """Summarize the current round together with recent history.

        Returns the raw content unchanged when summarization is enabled and the
        content is below ``summary_single_context_length``.
        """
        # NOTE(review): this early return fires only when summaries are
        # ENABLED and the content is short — confirm the gating is intended.
        if self.config.enable_summary and len(to_be_summary.content) < self.config.summary_single_context_length:
            return to_be_summary.content

        logger.info(f"🤔 [Summary] Creating summary memory, history messages [filters -> {filters}, "
                    f"last_rounds -> {last_rounds}]: to be summary content is {to_be_summary.content}")
        history_items = self.memory_store.get_last_n(last_rounds, filters=filters)
        if len(history_items) == 0:
            return ""
        parsed_messages = self._get_parsed_history_messages(history_items)

        # Append the current round after the retrieved history.
        parsed_messages.append({
            "role": to_be_summary.metadata['role'],
            "content": f"{to_be_summary.content}",
            'tool_call_id': to_be_summary.metadata['tool_call_id'],
        })
        history_context = self._build_history_context(parsed_messages)

        summary_messages = [
            {"role": "user", "content": self.config.summary_prompt.format(context=history_context)}
        ]

        return await self._call_llm_summary(summary_messages)

    def search(self, query, limit=100, filters=None) -> Optional[list["MemoryItem"]]:
        """Not implemented here; concrete subclasses may override."""
        pass
253
+
254
+
255
class InMemoryStorageMemory(Memory):
    """In-process memory backend with optional rolling summaries.

    Once the number of stored rounds exceeds ``summary_rounds``, older history
    is condensed into summary items kept in ``self.summary``, keyed by
    "<start>_<end>" index ranges.
    """

    def __init__(self, memory_store: "MemoryStore", config: "MemoryConfig", enable_summary: bool = True, **kwargs):
        super().__init__(memory_store=memory_store, config=config)
        # range_key "start_end" -> summary MemoryItem
        self.summary = {}
        self.summary_rounds = self.config.summary_rounds
        # NOTE(review): the `enable_summary` parameter is ignored; the value
        # always comes from config (preserved from the original behavior).
        self.enable_summary = self.config.enable_summary

    def add(self, memory_item: "MemoryItem", filters: dict = None):
        """Store an item, refreshing summaries once past the rounds threshold."""
        self.memory_store.add(memory_item)

        # Check if we need to create or update a summary.
        if self.enable_summary:
            total_rounds = len(self.memory_store.get_all())
            if total_rounds > self.summary_rounds:
                self._create_or_update_summary(total_rounds)

    def _create_or_update_summary(self, total_rounds: int):
        """Create or update the summary covering the oldest unsummarized span.

        Args:
            total_rounds (int): Total number of stored rounds.
        """
        summary_index = int(total_rounds / self.summary_rounds)
        start = (summary_index - 1) * self.summary_rounds
        end = total_rounds - self.summary_rounds

        # Clamp to a valid, non-inverted index range.
        start = max(0, start)
        end = max(start, end)

        # Get the memory items to summarize.
        items_to_summarize = self.memory_store.get_all()[start:end + 1]
        # Fix: replaced a stray debug print() with proper logging.
        logger.debug(f"_create_or_update_summary: total_rounds={total_rounds}, start={start}, end={end}")

        # Create summary content via the LLM.
        summary_content = self._summarize_items(items_to_summarize, summary_index)

        range_key = f"{start}_{end}"

        if range_key in self.summary:
            # Refresh the existing summary in place.
            self.summary[range_key].content = summary_content
            self.summary[range_key].updated_at = None  # This will update the timestamp
        else:
            # Create a new summary item for this range.
            summary_item = MemoryItem(
                content=summary_content,
                metadata={
                    "summary_index": summary_index,
                    "start_round": start,
                    "end_round": end,
                    "role": "system"
                },
                tags=["summary"]
            )
            self.summary[range_key] = summary_item

    def _summarize_items(self, items: list["MemoryItem"], summary_index: int) -> str:
        """Summarize a list of memory items via the LLM.

        Args:
            items (list[MemoryItem]): Items to summarize.
            summary_index (int): Summary index (currently unused here).

        Returns:
            str: Summary content.

        NOTE(review): asyncio.run() raises if invoked from a running event
        loop — confirm `add` is only ever called from synchronous code.
        """
        return asyncio.run(self.async_gen_multi_rounds_summary(items))

    def update(self, memory_item: "MemoryItem"):
        """Replace the stored item that has the same id."""
        self.memory_store.update(memory_item)

    def delete(self, memory_id):
        """Soft-delete the item with the given id."""
        self.memory_store.delete(memory_id)

    def get(self, memory_id) -> Optional["MemoryItem"]:
        """Return the item with the given id, or None."""
        return self.memory_store.get(memory_id)

    def get_all(self, filters: dict = None) -> list["MemoryItem"]:
        """Return all stored items matching filters.

        Fix: `filters` was previously accepted but silently ignored.
        """
        return self.memory_store.get_all(filters)

    def get_last_n(self, last_rounds, add_first_message=True, filters: dict = None) -> list["MemoryItem"]:
        """Get the last n memories, preceded by summaries of older history.

        Args:
            last_rounds (int): Number of recent memories to retrieve.
            add_first_message (bool): Whether to prepend the first stored
                message when older history exists.
            filters (dict): Optional metadata filters.

        Returns:
            list[MemoryItem]: Summary items (oldest first) followed by the
            last ``last_rounds`` raw items.
        """
        memory_items = self.memory_store.get_last_n(last_rounds)
        # Grow the window backwards while it starts on a tool result, so a
        # tool message is never returned without its triggering call.
        while len(memory_items) > 0 and memory_items[0].metadata and "tool_call_id" in memory_items[0].metadata and \
                memory_items[0].metadata["tool_call_id"]:
            last_rounds = last_rounds + 1
            memory_items = self.memory_store.get_last_n(last_rounds)

        # If summary is disabled or no summaries exist, return just the window.
        if not self.enable_summary or not self.summary:
            return memory_items

        # Everything older than the window is represented by summaries.
        all_items = self.memory_store.get_all()
        total_items = len(all_items)
        end_index = total_items - last_rounds

        result = []
        complete_summary_count = end_index // self.summary_rounds

        # Complete summary spans first.
        for i in range(complete_summary_count):
            range_key = f"{i * self.summary_rounds}_{(i + 1) * self.summary_rounds - 1}"
            if range_key in self.summary:
                result.append(self.summary[range_key])

        # Trailing partial span, if any.
        remaining_items = end_index % self.summary_rounds
        if remaining_items > 0:
            start = complete_summary_count * self.summary_rounds
            range_key = f"{start}_{end_index - 1}"
            if range_key in self.summary:
                result.append(self.summary[range_key])

        # Add the last n items.
        result.extend(memory_items)

        # Fix: the first message used to be inserted into `memory_items` AFTER
        # it had already been copied into `result`, so it never appeared in the
        # returned list. TODO(review): confirm the intended insert position.
        if add_first_message and last_rounds < self.memory_store.total_rounds():
            first_item = self.memory_store.get_first()
            if first_item is not None:
                result.insert(0, first_item)

        return result
aworld/memory/models.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, ConfigDict, Field
2
+ from aworld.core.memory import MemoryItem
3
+ from typing import Any, Dict, List, Optional, Literal
4
+
5
+ from aworld.models.model_response import ToolCall
6
+
7
class MessageMetadata(BaseModel):
    """Metadata attached to a memory message.

    Identifies who produced a message and in which session/task context.
    Extra keys are permitted and preserved (``extra="allow"``).
    """
    user_id: str = Field(description="The ID of the user")
    session_id: str = Field(description="The ID of the session")
    task_id: str = Field(description="The ID of the task")
    agent_id: str = Field(description="The ID of the agent")
    agent_name: str = Field(description="The name of the agent")

    model_config = ConfigDict(extra="allow")

    @property
    def to_dict(self) -> Dict[str, Any]:
        """Plain-dict view of this metadata (note: a property, not a method)."""
        return self.model_dump()
27
+
28
class AgentExperienceItem(BaseModel):
    """Payload stored as the content of an AgentExperience memory item."""
    skill: str = Field(description="The skill demonstrated in the experience")
    actions: List[str] = Field(description="The actions taken by the agent")
31
+
32
class AgentExperience(MemoryItem):
    """An agent's experience: a skill together with the actions that used it.

    Custom attributes live in the item's content and metadata.

    Args:
        agent_id (str): The ID of the agent.
        skill (str): The skill demonstrated in the experience.
        actions (List[str]): The actions taken by the agent.
        metadata (Optional[Dict[str, Any]]): Additional metadata.
    """
    def __init__(self, agent_id: str, skill: str, actions: List[str], metadata: Optional[Dict[str, Any]] = None) -> None:
        # Copy so the caller's dict is never mutated.
        meta = dict(metadata) if metadata else {}
        meta['agent_id'] = agent_id
        payload = AgentExperienceItem(skill=skill, actions=actions)
        super().__init__(content=payload, metadata=meta, memory_type="agent_experience")

    @property
    def agent_id(self) -> str:
        """Owning agent's ID (stored in metadata)."""
        return self.metadata['agent_id']

    @property
    def skill(self) -> str:
        """Skill recorded in the content payload."""
        return self.content.skill

    @property
    def actions(self) -> List[str]:
        """Actions recorded in the content payload."""
        return self.content.actions
59
+
60
class UserProfileItem(BaseModel):
    """Key/value payload stored as the content of a UserProfile memory item."""
    key: str = Field(description="The key of the profile")
    value: Any = Field(description="The value of the profile")
63
+
64
class UserProfile(MemoryItem):
    """A single user-profile key/value pair stored as a memory item.

    Custom attributes live in the item's content and metadata.

    Args:
        user_id (str): The ID of the user.
        key (str): The profile key.
        value (Any): The profile value.
        metadata (Optional[Dict[str, Any]]): Additional metadata.
    """
    def __init__(self, user_id: str, key: str, value: Any, metadata: Optional[Dict[str, Any]] = None) -> None:
        # Copy so the caller's dict is never mutated.
        meta = dict(metadata) if metadata else {}
        meta['user_id'] = user_id
        payload = UserProfileItem(key=key, value=value)
        super().__init__(content=payload, metadata=meta, memory_type="user_profile")

    @property
    def user_id(self) -> str:
        """Owning user's ID (stored in metadata)."""
        return self.metadata['user_id']

    @property
    def key(self) -> str:
        """Profile key from the content payload."""
        return self.content.key

    @property
    def value(self) -> Any:
        """Profile value from the content payload."""
        return self.content.value
91
+
92
class MemoryMessage(MemoryItem):
    """A chat message stored as a memory item.

    The sender role plus the user/session/task/agent identifiers are flattened
    into the item's metadata dict.

    Args:
        role (str): The role of the message sender.
        metadata (MessageMetadata): User, session, task, and agent IDs.
        content (Optional[Any]): Content of the message.
    """
    def __init__(self, role: str, metadata: MessageMetadata, content: Optional[Any] = None) -> None:
        flattened = metadata.to_dict  # to_dict is a property returning a plain dict
        flattened['role'] = role
        super().__init__(content=content, metadata=flattened, memory_type="message")

    @property
    def role(self) -> str:
        """Sender role (e.g. 'system', 'human', 'assistant', 'tool')."""
        return self.metadata['role']

    @property
    def user_id(self) -> str:
        return self.metadata['user_id']

    @property
    def session_id(self) -> str:
        return self.metadata['session_id']

    @property
    def task_id(self) -> str:
        return self.metadata['task_id']

    @property
    def agent_id(self) -> str:
        return self.metadata['agent_id']
124
+
125
class SystemMessage(MemoryMessage):
    """A memory message with the fixed role "system".

    Args:
        content (str): The content of the message.
        metadata (MessageMetadata): User, session, task, and agent IDs.
    """
    def __init__(self, content: str, metadata: MessageMetadata) -> None:
        super().__init__(role="system", metadata=metadata, content=content)

    @property
    def content(self) -> str:
        # Backing attribute presumably set by MemoryItem — TODO confirm.
        return self._content
138
+
139
class HumanMessage(MemoryMessage):
    """A memory message with the fixed role "human".

    Args:
        metadata (MessageMetadata): User, session, task, and agent IDs.
        content (str): The content of the message.

    NOTE(review): parameter order (metadata, content) differs from
    SystemMessage/AIMessage (content first); kept as-is for compatibility.
    """
    def __init__(self, metadata: MessageMetadata, content: str) -> None:
        super().__init__(role="human", metadata=metadata, content=content)

    @property
    def content(self) -> str:
        # Backing attribute presumably set by MemoryItem — TODO confirm.
        return self._content
152
+
153
class AIMessage(MemoryMessage):
    """A memory message with the fixed role "assistant", carrying tool calls.

    Args:
        content (str): The content of the message.
        tool_calls (List[ToolCall]): Tool calls issued by the assistant.
        metadata (MessageMetadata): User, session, task, and agent IDs.
    """
    def __init__(self, content: str, tool_calls: List[ToolCall], metadata: MessageMetadata) -> None:
        meta = metadata.to_dict
        # Robustness fix: tolerate tool_calls=None (treated as "no tool calls")
        # instead of raising TypeError while iterating.
        meta['tool_calls'] = [tool_call.to_dict() for tool_call in (tool_calls or [])]
        super().__init__(role="assistant", metadata=MessageMetadata(**meta), content=content)

    @property
    def content(self) -> str:
        # Backing attribute presumably set by MemoryItem — TODO confirm.
        return self._content

    @property
    def tool_calls(self) -> List[ToolCall]:
        """Rehydrate ToolCall objects from the serialized metadata dicts."""
        return [ToolCall(**tool_call) for tool_call in self.metadata['tool_calls']]
172
+
173
class ToolMessage(MemoryMessage):
    """A memory message with the fixed role "tool".

    Args:
        tool_call_id (str): The ID of the tool call this message answers.
        content (str): The content of the message.
        status (Literal["success", "error"]): The status of the tool call.
        metadata (MessageMetadata): User, session, task, and agent IDs.
            Required despite the None default (kept for signature
            compatibility).

    Raises:
        ValueError: If ``metadata`` is None.
    """
    def __init__(self, tool_call_id: str, content: str, status: Literal["success", "error"] = "success", metadata: MessageMetadata = None) -> None:
        # Fix: the original dereferenced metadata unconditionally, so a None
        # default crashed with AttributeError; fail fast with a clear error.
        if metadata is None:
            raise ValueError("ToolMessage requires a MessageMetadata instance")
        metadata.tool_call_id = tool_call_id
        metadata.status = status
        super().__init__(role="tool", metadata=metadata, content=content)

    @property
    def tool_call_id(self) -> str:
        """ID of the tool call this message responds to."""
        return self.metadata['tool_call_id']

    @property
    def status(self) -> str:
        """Outcome of the tool call: "success" or "error"."""
        return self.metadata['status']

    @property
    def content(self) -> str:
        # Backing attribute presumably set by MemoryItem — TODO confirm.
        return self._content
aworld/memory/utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tiktoken
2
+ from aworld.logs.util import logger
3
+
4
+ # TODO: merge to `models` package
5
+
6
# Mapping from OpenAI model name to its tiktoken encoding name.
MODEL_TO_ENCODING = {
    "gpt-3.5-turbo": "cl100k_base",
    "gpt-4": "cl100k_base",
    "text-davinci-003": "p50k_base",
    "text-embedding-ada-002": "cl100k_base",
    "text-curie-001": "r50k_base",
    "text-babbage-001": "r50k_base",
    "text-ada-001": "r50k_base",
}

def get_encoding_for_model(model_name: str) -> str:
    """
    Return the tiktoken encoding *name* for a model.

    Unknown models fall back to "cl100k_base" with a warning.

    Fix: the return annotation claimed ``tiktoken.Encoding``, but the function
    has always returned the encoding name string (callers pass the result to
    ``tiktoken.get_encoding``).
    """
    encoding_name = MODEL_TO_ENCODING.get(model_name)
    if encoding_name is None:
        logger.warning(f"model '{model_name}' not found in mapping table.")
        return "cl100k_base"
    return encoding_name
25
+
26
def count_tokens(model_name: str, content: str):
    """Return the number of tokens in *content* for *model_name*'s encoding."""
    encoding = tiktoken.get_encoding(get_encoding_for_model(model_name))
    return len(encoding.encode(content))