broadfield-dev commited on
Commit
54137b6
·
verified ·
1 Parent(s): 3b79be2

Update memory_logic.py

Browse files
Files changed (1) hide show
  1. memory_logic.py +14 -29
memory_logic.py CHANGED
@@ -149,7 +149,7 @@ def initialize_memory_system():
149
  raw_rules = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
150
  except Exception as e: logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)
151
 
152
- rule_index, valid_rules = _build_faiss_index(sorted(list(set(raw_rules))), lambda r: r)
153
  _faiss_rules_index = rule_index
154
  _rules_items_list = valid_rules
155
  logger.info(f"Loaded and indexed {len(_rules_items_list)} rules.")
@@ -160,18 +160,6 @@ def initialize_memory_system():
160
  else:
161
  logger.error("Memory system initialization failed. Core components are not ready.")
162
 
163
- def _verify_and_rebuild_if_needed(index, items_list, text_extraction_fn):
164
- global _memory_items_list, _rules_items_list
165
- if not index or index.ntotal != len(items_list):
166
- logger.warning(f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). Rebuilding...")
167
- new_index, valid_items = _build_faiss_index(items_list, text_extraction_fn)
168
- if items_list is _memory_items_list:
169
- _memory_items_list = valid_items
170
- elif items_list is _rules_items_list:
171
- _rules_items_list = valid_items
172
- return new_index
173
- return index
174
-
175
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
176
  global _memory_items_list, _faiss_memory_index
177
  if not _initialized: initialize_memory_system()
@@ -197,12 +185,8 @@ def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple
197
  return False, f"Error: {e}"
198
 
199
  def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
200
- global _faiss_memory_index
201
  if not _initialized: initialize_memory_system()
202
  if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
203
-
204
- _faiss_memory_index = _verify_and_rebuild_if_needed(_faiss_memory_index, _memory_items_list, lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}")
205
- if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
206
 
207
  try:
208
  query_embedding = _embedder.encode([query], convert_to_tensor=False)
@@ -225,7 +209,6 @@ def add_rule_entry(rule_text: str) -> tuple[bool, str]:
225
  embedding = _embedder.encode([rule_text], convert_to_tensor=False)
226
  _faiss_rules_index.add(np.array(embedding, dtype=np.float32))
227
  _rules_items_list.append(rule_text)
228
- _rules_items_list.sort()
229
 
230
  if STORAGE_BACKEND == "SQLITE":
231
  with _get_sqlite_connection() as conn: conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,)); conn.commit()
@@ -237,13 +220,9 @@ def add_rule_entry(rule_text: str) -> tuple[bool, str]:
237
  return False, f"Error: {e}"
238
 
239
  def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
240
- global _faiss_rules_index
241
  if not _initialized: initialize_memory_system()
242
  if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []
243
 
244
- _faiss_rules_index = _verify_and_rebuild_if_needed(_faiss_rules_index, _rules_items_list, lambda r: r)
245
- if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []
246
-
247
  try:
248
  query_embedding = _embedder.encode([query], convert_to_tensor=False)
249
  distances, indices = _faiss_rules_index.search(np.array(query_embedding, dtype=np.float32), min(k, _faiss_rules_index.ntotal))
@@ -256,17 +235,23 @@ def remove_rule_entry(rule_text_to_delete: str) -> bool:
256
  global _rules_items_list, _faiss_rules_index
257
  if not _initialized: initialize_memory_system()
258
  rule_text_to_delete = rule_text_to_delete.strip()
259
- if rule_text_to_delete not in _rules_items_list: return False
260
  try:
261
- new_list = [r for r in _rules_items_list if r != rule_text_to_delete]
262
- new_index, valid_items = _build_faiss_index(new_list, lambda r: r)
263
- _faiss_rules_index = new_index
264
- _rules_items_list = valid_items
 
 
 
265
 
266
  if STORAGE_BACKEND == "SQLITE":
267
  with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,)); conn.commit()
268
- elif STORAGE_BACKEND == "HF_DATASET" and _rules_items_list:
 
 
269
  Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
 
270
  return True
271
  except Exception as e:
272
  logger.error(f"Error removing rule: {e}", exc_info=True)
@@ -274,7 +259,7 @@ def remove_rule_entry(rule_text_to_delete: str) -> bool:
274
 
275
  def get_all_rules_cached() -> list[str]:
276
  if not _initialized: initialize_memory_system()
277
- return list(_rules_items_list)
278
 
279
  def get_all_memories_cached() -> list[dict]:
280
  if not _initialized: initialize_memory_system()
 
149
  raw_rules = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
150
  except Exception as e: logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)
151
 
152
+ rule_index, valid_rules = _build_faiss_index(raw_rules, lambda r: r)
153
  _faiss_rules_index = rule_index
154
  _rules_items_list = valid_rules
155
  logger.info(f"Loaded and indexed {len(_rules_items_list)} rules.")
 
160
  else:
161
  logger.error("Memory system initialization failed. Core components are not ready.")
162
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
164
  global _memory_items_list, _faiss_memory_index
165
  if not _initialized: initialize_memory_system()
 
185
  return False, f"Error: {e}"
186
 
187
  def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
 
188
  if not _initialized: initialize_memory_system()
189
  if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
 
 
 
190
 
191
  try:
192
  query_embedding = _embedder.encode([query], convert_to_tensor=False)
 
209
  embedding = _embedder.encode([rule_text], convert_to_tensor=False)
210
  _faiss_rules_index.add(np.array(embedding, dtype=np.float32))
211
  _rules_items_list.append(rule_text)
 
212
 
213
  if STORAGE_BACKEND == "SQLITE":
214
  with _get_sqlite_connection() as conn: conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,)); conn.commit()
 
220
  return False, f"Error: {e}"
221
 
222
  def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
 
223
  if not _initialized: initialize_memory_system()
224
  if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []
225
 
 
 
 
226
  try:
227
  query_embedding = _embedder.encode([query], convert_to_tensor=False)
228
  distances, indices = _faiss_rules_index.search(np.array(query_embedding, dtype=np.float32), min(k, _faiss_rules_index.ntotal))
 
235
  global _rules_items_list, _faiss_rules_index
236
  if not _initialized: initialize_memory_system()
237
  rule_text_to_delete = rule_text_to_delete.strip()
238
+
239
  try:
240
+ idx_to_remove = _rules_items_list.index(rule_text_to_delete)
241
+ except ValueError:
242
+ return False
243
+
244
+ try:
245
+ _faiss_rules_index.remove_ids(np.array([idx_to_remove], dtype='int64'))
246
+ del _rules_items_list[idx_to_remove]
247
 
248
  if STORAGE_BACKEND == "SQLITE":
249
  with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,)); conn.commit()
250
+ elif STORAGE_BACKEND == "HF_DATASET":
251
+ # After removing, we need to push the new state of the list.
252
+ # Important: This can be slow if the dataset is large.
253
  Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
254
+
255
  return True
256
  except Exception as e:
257
  logger.error(f"Error removing rule: {e}", exc_info=True)
 
259
 
260
  def get_all_rules_cached() -> list[str]:
261
  if not _initialized: initialize_memory_system()
262
+ return sorted(list(_rules_items_list))
263
 
264
  def get_all_memories_cached() -> list[dict]:
265
  if not _initialized: initialize_memory_system()