Update memory_logic.py
memory_logic.py  CHANGED  (+14, -29)
@@ -149,7 +149,7 @@ def initialize_memory_system():
             raw_rules = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
         except Exception as e: logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)
 
-        rule_index, valid_rules = _build_faiss_index(
+        rule_index, valid_rules = _build_faiss_index(raw_rules, lambda r: r)
         _faiss_rules_index = rule_index
         _rules_items_list = valid_rules
         logger.info(f"Loaded and indexed {len(_rules_items_list)} rules.")
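Note: `_build_faiss_index` is defined elsewhere in this file, so its body is not part of this diff. Below is a minimal sketch of the shape this call appears to assume (filter items through the text-extraction function, embed them with the module-level `_embedder`, pack the vectors into a flat L2 FAISS index). The helper name comes from the diff; the body is an assumption, not the Space's actual implementation.

```python
import numpy as np
import faiss  # assumed dependency, consistent with the float32 arrays added to the indexes elsewhere in this file


def _build_faiss_index(items, text_extraction_fn):
    # Keep only items that yield non-empty text; the returned list stays
    # positionally aligned with the vectors added to the index.
    valid_items = [it for it in items if text_extraction_fn(it)]
    if not valid_items:
        # _embedder is the module-level SentenceTransformer used throughout this file.
        return faiss.IndexFlatL2(_embedder.get_sentence_embedding_dimension()), []
    embeddings = _embedder.encode(
        [text_extraction_fn(it) for it in valid_items], convert_to_tensor=False
    )
    embeddings = np.asarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, valid_items
```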
@@ -160,18 +160,6 @@ def initialize_memory_system():
     else:
         logger.error("Memory system initialization failed. Core components are not ready.")
 
-def _verify_and_rebuild_if_needed(index, items_list, text_extraction_fn):
-    global _memory_items_list, _rules_items_list
-    if not index or index.ntotal != len(items_list):
-        logger.warning(f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). Rebuilding...")
-        new_index, valid_items = _build_faiss_index(items_list, text_extraction_fn)
-        if items_list is _memory_items_list:
-            _memory_items_list = valid_items
-        elif items_list is _rules_items_list:
-            _rules_items_list = valid_items
-        return new_index
-    return index
-
 def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
     global _memory_items_list, _faiss_memory_index
     if not _initialized: initialize_memory_system()
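Note: the deleted `_verify_and_rebuild_if_needed` rebuilt an index whenever `index.ntotal` drifted from the length of its backing list. The rest of this commit instead keeps the two containers in lockstep (append plus `add` on insert, `del` plus `remove_ids` on delete, and no more in-place sorting), so position i in each FAISS index always corresponds to element i of its list. A hypothetical one-line sanity check expressing that invariant, not part of the file:

```python
def _assert_aligned(index, items_list):
    # Retrieval maps FAISS result ids straight to list positions,
    # so the index and its backing list must never diverge in length.
    assert index.ntotal == len(items_list), (
        f"FAISS index ({index.ntotal}) and backing list ({len(items_list)}) diverged"
    )
```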
@@ -197,12 +185,8 @@ def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
         return False, f"Error: {e}"
 
 def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
-    global _faiss_memory_index
     if not _initialized: initialize_memory_system()
     if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
-
-    _faiss_memory_index = _verify_and_rebuild_if_needed(_faiss_memory_index, _memory_items_list, lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}")
-    if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
 
     try:
         query_embedding = _embedder.encode([query], convert_to_tensor=False)
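Note: the code that consumes `query_embedding` falls outside this hunk, but the (removed) text-extraction lambda above shows that `_memory_items_list` stores JSON strings. A hedged sketch of how FAISS hits would map back to memories under that assumption; the helper name is hypothetical:

```python
import json


def _map_hits_to_memories(indices, memory_items_list):
    # FAISS returns row positions; they line up with the backing list because
    # entries are appended to the index and the list in the same order.
    hits = []
    for i in indices[0]:
        if i < 0:  # FAISS pads results with -1 when fewer than k neighbours exist
            continue
        hits.append(json.loads(memory_items_list[i]))
    return hits
```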
@@ -225,7 +209,6 @@ def add_rule_entry(rule_text: str) -> tuple[bool, str]:
         embedding = _embedder.encode([rule_text], convert_to_tensor=False)
         _faiss_rules_index.add(np.array(embedding, dtype=np.float32))
         _rules_items_list.append(rule_text)
-        _rules_items_list.sort()
 
         if STORAGE_BACKEND == "SQLITE":
             with _get_sqlite_connection() as conn: conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,)); conn.commit()
@@ -237,13 +220,9 @@ def add_rule_entry(rule_text: str) -> tuple[bool, str]:
         return False, f"Error: {e}"
 
 def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
-    global _faiss_rules_index
     if not _initialized: initialize_memory_system()
     if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []
 
-    _faiss_rules_index = _verify_and_rebuild_if_needed(_faiss_rules_index, _rules_items_list, lambda r: r)
-    if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []
-
     try:
         query_embedding = _embedder.encode([query], convert_to_tensor=False)
         distances, indices = _faiss_rules_index.search(np.array(query_embedding, dtype=np.float32), min(k, _faiss_rules_index.ntotal))
@@ -256,17 +235,23 @@ def remove_rule_entry(rule_text_to_delete: str) -> bool:
     global _rules_items_list, _faiss_rules_index
     if not _initialized: initialize_memory_system()
     rule_text_to_delete = rule_text_to_delete.strip()
-
+
     try:
-
-
-
-
+        idx_to_remove = _rules_items_list.index(rule_text_to_delete)
+    except ValueError:
+        return False
+
+    try:
+        _faiss_rules_index.remove_ids(np.array([idx_to_remove], dtype='int64'))
+        del _rules_items_list[idx_to_remove]
 
         if STORAGE_BACKEND == "SQLITE":
             with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,)); conn.commit()
-        elif STORAGE_BACKEND == "HF_DATASET"
+        elif STORAGE_BACKEND == "HF_DATASET":
+            # After removing, we need to push the new state of the list.
+            # Important: This can be slow if the dataset is large.
             Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
+
         return True
     except Exception as e:
         logger.error(f"Error removing rule: {e}", exc_info=True)
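Note: the new removal path leans on a FAISS property worth spelling out. On a flat index, `remove_ids` drops the vector at that position and shifts every later vector down by one, which is exactly what `del _rules_items_list[idx_to_remove]` does to the backing list, so the two stay positionally aligned. A standalone illustration with made-up vectors (not data from this Space):

```python
import numpy as np
import faiss

index = faiss.IndexFlatL2(2)
index.add(np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]], dtype=np.float32))
items = ["rule a", "rule b", "rule c"]  # backing list, aligned by position

# Remove position 1 from both containers, as the patched remove_rule_entry does.
index.remove_ids(np.array([1], dtype="int64"))
del items[1]

print(index.ntotal, len(items))        # 2 2
print(index.reconstruct(1), items[1])  # [2. 2.] rule c -- later entries shift down together
```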
@@ -274,7 +259,7 @@ def remove_rule_entry(rule_text_to_delete: str) -> bool:
 
 def get_all_rules_cached() -> list[str]:
     if not _initialized: initialize_memory_system()
-    return list(_rules_items_list)
+    return sorted(list(_rules_items_list))
 
 def get_all_memories_cached() -> list[dict]:
     if not _initialized: initialize_memory_system()
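Note: two changes above work together. `add_rule_entry` no longer calls `_rules_items_list.sort()` (an in-place sort would reshuffle list positions out from under the FAISS index), and `get_all_rules_cached` now returns a sorted copy instead. A toy illustration with invented rule texts:

```python
rules = ["zebra rule", "alpha rule"]  # positions 0 and 1 match FAISS ids 0 and 1

# In-place sorting (the removed behaviour) would move "alpha rule" to position 0
# while its vector stayed at id 1, silently breaking retrieval.

# Sorting a copy on the way out (the new behaviour) leaves the alignment untouched:
def get_all_rules_cached_sketch(rules_items_list):
    return sorted(list(rules_items_list))

print(get_all_rules_cached_sketch(rules))  # ['alpha rule', 'zebra rule']
print(rules)                               # ['zebra rule', 'alpha rule'] -- unchanged
```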