Duibonduil commited on
Commit
7214bec
·
verified ·
1 Parent(s): c0e84de

Upload mem0_memory.py

Browse files
Files changed (1) hide show
  1. aworld/memory/mem0/mem0_memory.py +193 -0
aworld/memory/mem0/mem0_memory.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import traceback
4
+ from typing import Optional
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from aworld.config import ConfigDict
9
+ from aworld.core.memory import MemoryStore, MemoryConfig, MemoryItem
10
+ from aworld.logs.util import logger
11
+ from aworld.memory.main import Memory
12
+ from aworld.models.llm import get_llm_model
13
+
14
+
15
+ class Mem0Memory(Memory):
16
+ def __init__(self, memory_store: MemoryStore, config: MemoryConfig | None = None, **kwargs):
17
+ super().__init__(memory_store, config, **kwargs)
18
+ self.config = config
19
+
20
+ conf = ConfigDict(
21
+ llm_provider=config.llm_provider,
22
+ llm_model_name=os.getenv("MEM_LLM_MODEL_NAME") if os.getenv("MEM_LLM_MODEL_NAME") else os.getenv(
23
+ 'LLM_MODEL_NAME'),
24
+ llm_temperature=os.getenv("MEM_LLM_TEMPERATURE") if os.getenv("MEM_LLM_TEMPERATURE") else 1.0,
25
+ llm_base_url=os.getenv("MEM_LLM_BASE_URL") if os.getenv("MEM_LLM_BASE_URL") else os.getenv('LLM_BASE_URL'),
26
+ llm_api_key=os.getenv("MEM_LLM_API_KEY") if os.getenv("MEM_LLM_API_KEY") else os.getenv('LLM_API_KEY')
27
+ )
28
+ self.config.llm_instance = get_llm_model(conf=conf, streaming=False)
29
+
30
+ # Check for required packages
31
+ try:
32
+ # also disable mem0's telemetry when ANONYMIZED_TELEMETRY=False
33
+ if os.getenv('ANONYMIZED_TELEMETRY', 'true').lower()[0] in 'fn0':
34
+ os.environ['MEM_TELEMETRY'] = 'False'
35
+ from mem0 import Memory as Mem0
36
+ except ImportError:
37
+ raise ImportError('mem0 is required when enable_memory=True. Please install it with `pip install mem0`.')
38
+
39
+ if self.config.embedder_provider == 'huggingface':
40
+ try:
41
+ # check that required package is installed if huggingface is used
42
+ from sentence_transformers import SentenceTransformer # noqa: F401
43
+ except ImportError:
44
+ raise ImportError(
45
+ 'sentence_transformers is required when enable_memory=True and embedder_provider="huggingface". Please install it with `pip install sentence-transformers`.'
46
+ )
47
+
48
+ # Initialize Mem0 with the configuration
49
+ config_dict = self.config.full_config_dict
50
+ self.mem0 = Mem0.from_config(config_dict=self.config.full_config_dict)
51
+ self.memory_store = memory_store
52
+
53
+ def add(self, memory_item: MemoryItem, filters: dict = None):
54
+ # generate summary memory if needed
55
+ message_filters = {
56
+ "memory_type": "message"
57
+ }
58
+ if filters:
59
+ message_filters = {
60
+ "memory_type": "message",
61
+ "agent_id": memory_item.metadata.get("agent_id"),
62
+ "task_id": memory_item.metadata.get("task_id"),
63
+ "user_id": memory_item.metadata.get("user_id"),
64
+ "session_id": memory_item.metadata.get("session_id"),
65
+ }
66
+ if self._need_summary(memory_item, message_filters):
67
+ self.create_summary_memory(
68
+ agent_id=memory_item.metadata.get("agent_id"),
69
+ task_id=memory_item.metadata.get("task_id"),
70
+ user_id=memory_item.metadata.get("user_id"),
71
+ session_id=memory_item.metadata.get("session_id"),
72
+ filters=message_filters
73
+ )
74
+ self.memory_store.add(memory_item)
75
+
76
+ def _need_summary(self, memory_item, message_filters):
77
+ """
78
+ Check if a summary is needed based on the current step.
79
+ 1. If the number of messages is greater than the summary rounds.
80
+ 2. If the message is a message and the content is greater than the summary single context length.
81
+ """
82
+ return self.memory_store.total_rounds(message_filters) > self.config.summary_rounds or (
83
+ memory_item.memory_type == 'message' and len(
84
+ memory_item.content) >= self.config.summary_single_context_length)
85
+
86
+ def create_summary_memory(self, agent_id, task_id, user_id, session_id, filters: dict) -> None:
87
+ """
88
+ Create a summary memory if needed based on the current step.
89
+ """
90
+ logger.info(f'Creating summary memory, {filters}')
91
+
92
+ # Get all messages
93
+ all_messages = self.memory_store.get_all(filters=filters)
94
+
95
+ # Separate messages into those to keep as-is and those to process for memory
96
+ summary_messages = []
97
+ messages_to_process = []
98
+
99
+ for msg in all_messages:
100
+ if isinstance(msg, MemoryItem) and msg.memory_type in {'summary'}:
101
+ # Keep system and memory messages as they are
102
+ summary_messages.append(msg)
103
+ elif msg.memory_type in {'init'}:
104
+ messages_to_process.append(msg)
105
+ else:
106
+ if len(msg.content) > 0:
107
+ messages_to_process.append(msg)
108
+ if messages_to_process[-1].metadata.get("tool_calls"):
109
+ messages_to_process = messages_to_process[:-1]
110
+ # Need at least 1 message to create a meaningful summary
111
+ if len(messages_to_process) < 1:
112
+ logger.info('Not enough non-memory messages to summarize')
113
+ return
114
+ # Create a procedural memory
115
+
116
+ memory_content = self._create_summary_memory(messages_to_process)
117
+
118
+ if not memory_content:
119
+ logger.warning('Failed to create procedural memory')
120
+ return
121
+
122
+ # Add the summary message
123
+ summary_message = MemoryItem(content=memory_content, memory_type='summary', metadata={
124
+ "role": "user",
125
+ "agent_id": agent_id,
126
+ "session_id": session_id,
127
+ "task_id": task_id,
128
+ "user_id": user_id,
129
+ })
130
+ summary_messages.append(summary_message)
131
+
132
+ # Update the history
133
+ [self.memory_store.delete(m.id) for m in messages_to_process]
134
+ self.memory_store.add(summary_message)
135
+
136
+ logger.info(f'Messages consolidated: {len(messages_to_process)} messages converted to procedural memory')
137
+
138
+ def _create_summary_memory(self, messages: list[MemoryItem]) -> str | None:
139
+
140
+ parsed_messages = [{'role': message.metadata['role'], 'content': message.content if not message.metadata.get(
141
+ 'tool_calls') else message.content + "\n\n" + self.__format_tool_call(message.metadata.get('tool_calls'))}
142
+ for message in
143
+ messages] # TODO add tool_call from metadata['tool_calls'] such as [{"id": "fc-7b66b01a-f125-44d5-9f32-5e3723384d8e", "type": "function", "function": {"name": "mcp__amap-amap-sse__maps_geo", "arguments": "{\"address\": \"\u676d\u5dde\", \"city\": \"\u676d\u5dde\"}"}}] append to content
144
+ try:
145
+ results = self.mem0.add(
146
+ messages=parsed_messages,
147
+ agent_id=messages[-1].metadata.get('agent_id'),
148
+ memory_type='procedural_memory'
149
+ )
150
+ if len(results.get('results', [])):
151
+ logger.info(f'creating summary memory result: {results}')
152
+ return results.get('results', [])[0].get('memory')
153
+ return None
154
+ except Exception as e:
155
+ logger.error(f'Error creating summary memory: {e}')
156
+ traceback.print_exc()
157
+ return None
158
+
159
+ def __format_tool_call(self, tool_calls):
160
+ return json.dumps(tool_calls, default=lambda o: o.model_dump_json() if isinstance(o, BaseModel) else str(o))
161
+
162
+ def update(self, memory_item: MemoryItem):
163
+ self.memory_store.update(memory_item)
164
+
165
+ def delete(self, memory_id):
166
+ self.memory_store.delete(memory_id)
167
+
168
+ def get(self, memory_id) -> Optional[MemoryItem]:
169
+ # self.memory_store.get(memory_id)
170
+ return self.memory_store.get(
171
+ memory_id,
172
+ )
173
+
174
+ def get_all(self, filters: dict = None) -> list[MemoryItem]:
175
+ return self.memory_store.get_all(
176
+ filters=filters,
177
+ )
178
+
179
+ def get_last_n(self, last_rounds, add_first_message=True, filters: dict = None) -> list[MemoryItem]:
180
+ """
181
+ Get last n memories.
182
+
183
+ Args:
184
+ last_rounds (int): Number of memories to retrieve.
185
+ add_first_message (bool):
186
+
187
+ Returns:
188
+ list[MemoryItem]: List of latest memories.
189
+ """
190
+ return self.memory_store.get_last_n(
191
+ last_rounds=last_rounds,
192
+ filters=filters,
193
+ )