broadfield-dev committed on
Commit
ef71876
·
verified ·
1 Parent(s): c9ee5e6

Update memory_logic.py

Browse files
Files changed (1) hide show
  1. memory_logic.py +99 -17
memory_logic.py CHANGED
@@ -1,4 +1,3 @@
1
- # memory_logic.py
2
  import os
3
  import json
4
  import time
@@ -60,18 +59,40 @@ def _get_sqlite_connection():
60
  os.makedirs(db_dir, exist_ok=True)
61
  return sqlite3.connect(SQLITE_DB_PATH, timeout=10)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def _build_faiss_index_from_json_strings(memory_items: list[str]) -> faiss.Index | None:
64
  if not memory_items or not _embedder:
65
  return faiss.IndexFlatL2(_dimension)
66
 
67
  texts_to_embed = []
68
- valid_indices = []
69
- for i, mem_json_str in enumerate(memory_items):
70
  try:
71
  mem_obj = json.loads(mem_json_str)
72
  text = f"User: {mem_obj.get('user_input', '')}\nAI: {mem_obj.get('bot_response', '')}\nTakeaway: {mem_obj.get('metrics', {}).get('takeaway', 'N/A')}"
73
  texts_to_embed.append(text)
74
- valid_indices.append(i)
75
  except json.JSONDecodeError:
76
  continue
77
 
@@ -103,8 +124,7 @@ def initialize_memory_system():
103
  return
104
 
105
  logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
106
- init_start_time = time.time()
107
-
108
  if not all([SentenceTransformer, faiss, np]):
109
  logger.error("Core RAG libraries not available. Cannot initialize semantic memory.")
110
  return
@@ -117,6 +137,9 @@ def initialize_memory_system():
117
  logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
118
  return
119
 
 
 
 
120
  long_term_mems = []
121
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
122
  try:
@@ -131,13 +154,10 @@ def initialize_memory_system():
131
  except Exception as e: logger.error(f"Error loading long-term memories from HF Dataset: {e}")
132
 
133
  _long_term_memory_items_list = long_term_mems
134
- logger.info(f"Loaded {len(_long_term_memory_items_list)} long-term memory items.")
135
  _faiss_long_term_memory_index = _build_faiss_index_from_json_strings(_long_term_memory_items_list)
136
- logger.info(f"Long-term memory FAISS index built. Total items: {_faiss_long_term_memory_index.ntotal if _faiss_long_term_memory_index else 'N/A'}")
137
-
138
  _short_term_memory_items_list = []
139
  _faiss_short_term_memory_index = faiss.IndexFlatL2(_dimension)
140
- logger.info("Short-term memory initialized (empty).")
141
 
142
  temp_rules_text = []
143
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
@@ -156,10 +176,8 @@ def initialize_memory_system():
156
  if _rules_items_list:
157
  rule_embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False)
158
  _faiss_rules_index.add(np.array(rule_embeddings, dtype=np.float32))
159
- logger.info(f"Rules FAISS index built. Total items: {_faiss_rules_index.ntotal if _faiss_rules_index else 'N/A'}")
160
 
161
  _initialized = True
162
- logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
163
 
164
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
165
  if not _initialized: initialize_memory_system()
@@ -184,7 +202,6 @@ def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple
184
  all_mems_for_push = _long_term_memory_items_list + _short_term_memory_items_list
185
  Dataset.from_dict({"memory_json": list(set(all_mems_for_push))}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
186
 
187
- logger.info(f"Added memory. Short-term count: {_faiss_short_term_memory_index.ntotal}")
188
  return True, "Memory added successfully."
189
  except Exception as e:
190
  logger.error(f"Error adding memory entry: {e}", exc_info=True)
@@ -203,13 +220,11 @@ def search_memories(query: str, k: int = 3, threshold: float = 1.0) -> tuple[lis
203
  best_dist = distances[0][0] if len(distances[0]) > 0 else float('inf')
204
 
205
  if best_dist < threshold:
206
- logger.info(f"Found relevant short-term memories (best distance: {best_dist:.4f}).")
207
  for i in indices[0]:
208
  res = json.loads(_short_term_memory_items_list[i])
209
  final_results[res['timestamp']] = res
210
  return list(final_results.values()), search_path
211
 
212
- logger.info("No relevant short-term memories found. Escalating to deep search on long-term memory.")
213
  search_path = "deep"
214
 
215
  if _faiss_long_term_memory_index and _faiss_long_term_memory_index.ntotal > 0:
@@ -245,7 +260,6 @@ def get_all_memories_cached() -> list[dict]:
245
  except: continue
246
  return unique_mem_dicts
247
 
248
- # --- The rest of the utility functions (add_rule, get_rules, clear functions) remain the same ---
249
  def add_rule_entry(rule_text: str):
250
  global _rules_items_list, _faiss_rules_index
251
  if not _initialized: initialize_memory_system()
@@ -270,6 +284,74 @@ def add_rule_entry(rule_text: str):
270
  logger.error(f"Error adding rule: {e}", exc_info=True)
271
  return False, str(e)
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  def get_all_rules_cached() -> list[str]:
274
  if not _initialized: initialize_memory_system()
275
- return list(_rules_items_list)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
 
59
  os.makedirs(db_dir, exist_ok=True)
60
  return sqlite3.connect(SQLITE_DB_PATH, timeout=10)
61
 
62
+ def _init_sqlite_tables():
63
+ if STORAGE_BACKEND != "SQLITE" or not sqlite3:
64
+ return
65
+ try:
66
+ with _get_sqlite_connection() as conn:
67
+ cursor = conn.cursor()
68
+ cursor.execute("""
69
+ CREATE TABLE IF NOT EXISTS memories (
70
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
71
+ memory_json TEXT NOT NULL,
72
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
73
+ )
74
+ """)
75
+ cursor.execute("""
76
+ CREATE TABLE IF NOT EXISTS rules (
77
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
78
+ rule_text TEXT NOT NULL UNIQUE,
79
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
80
+ )
81
+ """)
82
+ conn.commit()
83
+ except Exception as e:
84
+ logger.error(f"SQLite table initialization error: {e}", exc_info=True)
85
+
86
  def _build_faiss_index_from_json_strings(memory_items: list[str]) -> faiss.Index | None:
87
  if not memory_items or not _embedder:
88
  return faiss.IndexFlatL2(_dimension)
89
 
90
  texts_to_embed = []
91
+ for mem_json_str in memory_items:
 
92
  try:
93
  mem_obj = json.loads(mem_json_str)
94
  text = f"User: {mem_obj.get('user_input', '')}\nAI: {mem_obj.get('bot_response', '')}\nTakeaway: {mem_obj.get('metrics', {}).get('takeaway', 'N/A')}"
95
  texts_to_embed.append(text)
 
96
  except json.JSONDecodeError:
97
  continue
98
 
 
124
  return
125
 
126
  logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
127
+
 
128
  if not all([SentenceTransformer, faiss, np]):
129
  logger.error("Core RAG libraries not available. Cannot initialize semantic memory.")
130
  return
 
137
  logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
138
  return
139
 
140
+ if STORAGE_BACKEND == "SQLITE":
141
+ _init_sqlite_tables()
142
+
143
  long_term_mems = []
144
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
145
  try:
 
154
  except Exception as e: logger.error(f"Error loading long-term memories from HF Dataset: {e}")
155
 
156
  _long_term_memory_items_list = long_term_mems
 
157
  _faiss_long_term_memory_index = _build_faiss_index_from_json_strings(_long_term_memory_items_list)
158
+
 
159
  _short_term_memory_items_list = []
160
  _faiss_short_term_memory_index = faiss.IndexFlatL2(_dimension)
 
161
 
162
  temp_rules_text = []
163
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
 
176
  if _rules_items_list:
177
  rule_embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False)
178
  _faiss_rules_index.add(np.array(rule_embeddings, dtype=np.float32))
 
179
 
180
  _initialized = True
 
181
 
182
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
183
  if not _initialized: initialize_memory_system()
 
202
  all_mems_for_push = _long_term_memory_items_list + _short_term_memory_items_list
203
  Dataset.from_dict({"memory_json": list(set(all_mems_for_push))}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
204
 
 
205
  return True, "Memory added successfully."
206
  except Exception as e:
207
  logger.error(f"Error adding memory entry: {e}", exc_info=True)
 
220
  best_dist = distances[0][0] if len(distances[0]) > 0 else float('inf')
221
 
222
  if best_dist < threshold:
 
223
  for i in indices[0]:
224
  res = json.loads(_short_term_memory_items_list[i])
225
  final_results[res['timestamp']] = res
226
  return list(final_results.values()), search_path
227
 
 
228
  search_path = "deep"
229
 
230
  if _faiss_long_term_memory_index and _faiss_long_term_memory_index.ntotal > 0:
 
260
  except: continue
261
  return unique_mem_dicts
262
 
 
263
  def add_rule_entry(rule_text: str):
264
  global _rules_items_list, _faiss_rules_index
265
  if not _initialized: initialize_memory_system()
 
284
  logger.error(f"Error adding rule: {e}", exc_info=True)
285
  return False, str(e)
286
 
287
+ def remove_rule_entry(rule_text_to_delete: str) -> bool:
288
+ global _rules_items_list, _faiss_rules_index
289
+ if not _initialized: initialize_memory_system()
290
+ if not _embedder or not _faiss_rules_index: return False
291
+ rule_text_to_delete = rule_text_to_delete.strip()
292
+ if rule_text_to_delete not in _rules_items_list:
293
+ return False
294
+ try:
295
+ _rules_items_list.remove(rule_text_to_delete)
296
+ _rules_items_list.sort()
297
+ new_faiss_rules_index = faiss.IndexFlatL2(_dimension)
298
+ if _rules_items_list:
299
+ embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False)
300
+ embeddings_np = np.array(embeddings, dtype=np.float32)
301
+ if embeddings_np.ndim == 2 and embeddings_np.shape[1] == _dimension:
302
+ new_faiss_rules_index.add(embeddings_np)
303
+ else:
304
+ _rules_items_list.append(rule_text_to_delete)
305
+ _rules_items_list.sort()
306
+ return False
307
+ _faiss_rules_index = new_faiss_rules_index
308
+ if STORAGE_BACKEND == "SQLITE" and sqlite3:
309
+ with _get_sqlite_connection() as conn:
310
+ conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,))
311
+ conn.commit()
312
+ elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
313
+ Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
314
+ return True
315
+ except Exception as e:
316
+ logger.error(f"Error removing rule entry: {e}", exc_info=True)
317
+ return False
318
+
319
  def get_all_rules_cached() -> list[str]:
320
  if not _initialized: initialize_memory_system()
321
+ return list(_rules_items_list)
322
+
323
+ def clear_all_memory_data_backend() -> bool:
324
+ global _long_term_memory_items_list, _short_term_memory_items_list, _faiss_long_term_memory_index, _faiss_short_term_memory_index
325
+ if not _initialized: initialize_memory_system()
326
+ success = True
327
+ try:
328
+ if STORAGE_BACKEND == "SQLITE" and sqlite3:
329
+ with _get_sqlite_connection() as conn: conn.execute("DELETE FROM memories"); conn.commit()
330
+ elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
331
+ Dataset.from_dict({"memory_json": []}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
332
+ _long_term_memory_items_list = []
333
+ _short_term_memory_items_list = []
334
+ if _faiss_long_term_memory_index: _faiss_long_term_memory_index.reset()
335
+ if _faiss_short_term_memory_index: _faiss_short_term_memory_index.reset()
336
+ except Exception as e:
337
+ logger.error(f"Error clearing all memory data: {e}")
338
+ success = False
339
+ return success
340
+
341
+ def clear_all_rules_data_backend() -> bool:
342
+ global _rules_items_list, _faiss_rules_index
343
+ if not _initialized: initialize_memory_system()
344
+ success = True
345
+ try:
346
+ if STORAGE_BACKEND == "SQLITE" and sqlite3:
347
+ with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules"); conn.commit()
348
+ elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
349
+ Dataset.from_dict({"rule_text": []}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
350
+ _rules_items_list = []
351
+ if _faiss_rules_index: _faiss_rules_index.reset()
352
+ except Exception as e:
353
+ logger.error(f"Error clearing all rules data: {e}")
354
+ success = False
355
+ return success
356
+
357
+ def save_faiss_indices_to_disk(): pass