broadfield-dev committed
Commit ea898c4 · verified · 1 Parent(s): 55aa812

Update memory_logic.py

Files changed (1)
  1. memory_logic.py +16 -29
memory_logic.py CHANGED
@@ -1,4 +1,3 @@
-# memory_logic.py
 import os
 import json
 import time
@@ -7,7 +6,6 @@ import logging
 import re
 import threading
 
-# Conditionally import heavy dependencies
 try:
     from sentence_transformers import SentenceTransformer
     import faiss
@@ -28,21 +26,16 @@ except ImportError:
     load_dataset, Dataset = None, None
     logging.warning("datasets library not installed. Hugging Face Dataset backend will be unavailable.")
 
-
 logger = logging.getLogger(__name__)
-# Suppress verbose logs from dependencies
 for lib_name in ["sentence_transformers", "faiss", "datasets", "huggingface_hub"]:
     if logging.getLogger(lib_name): logging.getLogger(lib_name).setLevel(logging.WARNING)
 
-
-# --- Configuration (Read directly from environment variables) ---
-STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper() #HF_DATASET, RAM, SQLITE
+STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper()
 SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db")
 HF_TOKEN = os.getenv("HF_TOKEN")
 HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain")
 HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules")
 
-# --- Globals for RAG within this module ---
 _embedder = None
 _dimension = 384
 _faiss_memory_index = None
@@ -74,10 +67,10 @@ def _init_sqlite_tables():
 def _build_faiss_index(items_list, text_extraction_fn):
     if not _embedder:
         logger.error("Cannot build FAISS index: Embedder not available.")
-        return None
+        return None, []
 
     index = faiss.IndexFlatL2(_dimension)
-    if not items_list: return index
+    if not items_list: return index, []
 
     texts_to_embed, valid_items = [], []
     for item in items_list:
@@ -89,7 +82,7 @@ def _build_faiss_index(items_list, text_extraction_fn):
 
     if not texts_to_embed:
         logger.warning("No valid items to embed for FAISS index after filtering.")
-        return index
+        return index, []
 
     try:
         embeddings = _embedder.encode(texts_to_embed, convert_to_tensor=False, show_progress_bar=False)
@@ -97,7 +90,6 @@ def _build_faiss_index(items_list, text_extraction_fn):
         if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(texts_to_embed):
             index.add(embeddings_np)
             logger.info(f"FAISS index built with {index.ntotal} / {len(items_list)} items.")
-            # Important: The original items_list is returned so we can update the global list to match the index
             return index, valid_items
         else:
             logger.error(f"FAISS build failed: Embeddings shape error.")
@@ -126,11 +118,10 @@ def initialize_memory_system():
         _dimension = _embedder.get_sentence_embedding_dimension() or 384
     except Exception as e:
         logger.critical(f"FATAL: Could not load SentenceTransformer model. Semantic search disabled. Error: {e}", exc_info=True)
-        return # Stop initialization if embedder fails
+        return
 
     if STORAGE_BACKEND == "SQLITE": _init_sqlite_tables()
 
-    # Load raw data
    raw_mems = []
    if STORAGE_BACKEND == "SQLITE":
        try: raw_mems = [row[0] for row in _get_sqlite_connection().execute("SELECT memory_json FROM memories")]
@@ -142,13 +133,11 @@ def initialize_memory_system():
            raw_mems = [m for m in dataset["train"]["memory_json"] if isinstance(m, str) and m.strip()]
        except Exception as e: logger.error(f"Error loading memories from HF Dataset: {e}", exc_info=True)
 
-    # Build Memory Index and get validated list
    mem_index, valid_mems = _build_faiss_index(raw_mems, lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}")
    _faiss_memory_index = mem_index
-    _memory_items_list = valid_mems # Use the validated list that matches the index
+    _memory_items_list = valid_mems
    logger.info(f"Loaded and indexed {len(_memory_items_list)} memories.")
 
-    # Load Rules
    raw_rules = []
    if STORAGE_BACKEND == "SQLITE":
        try: raw_rules = [row[0] for row in _get_sqlite_connection().execute("SELECT rule_text FROM rules")]
@@ -160,13 +149,11 @@ def initialize_memory_system():
            raw_rules = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
        except Exception as e: logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)
 
-    # Build Rules Index and get validated list
    rule_index, valid_rules = _build_faiss_index(sorted(list(set(raw_rules))), lambda r: r)
    _faiss_rules_index = rule_index
-    _rules_items_list = valid_rules # Use the validated list that matches the index
+    _rules_items_list = valid_rules
    logger.info(f"Loaded and indexed {len(_rules_items_list)} rules.")
 
-    # Only mark as initialized if the core components are ready
    if _embedder and _faiss_memory_index is not None and _faiss_rules_index is not None:
        _initialized = True
        logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
@@ -174,18 +161,19 @@
        logger.error("Memory system initialization failed. Core components are not ready.")
 
 def _verify_and_rebuild_if_needed(index, items_list, text_extraction_fn):
+    global _memory_items_list, _rules_items_list
    if not index or index.ntotal != len(items_list):
        logger.warning(f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). Rebuilding...")
        new_index, valid_items = _build_faiss_index(items_list, text_extraction_fn)
-        # This function is now stateful and modifies the global list to match the new index
-        if isinstance(items_list, list) and isinstance(valid_items, list):
-            # A bit of a hack to update the global list from here
-            items_list.clear()
-            items_list.extend(valid_items)
+        if items_list is _memory_items_list:
+            _memory_items_list = valid_items
+        elif items_list is _rules_items_list:
+            _rules_items_list = valid_items
        return new_index
    return index
 
 def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
+    global _memory_items_list, _faiss_memory_index
    if not _initialized: initialize_memory_system()
    if not _embedder or _faiss_memory_index is None: return False, "Memory system not ready."
 
@@ -225,6 +213,7 @@ def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
    return []
 
 def add_rule_entry(rule_text: str) -> tuple[bool, str]:
+    global _rules_items_list, _faiss_rules_index
    if not _initialized: initialize_memory_system()
    if not _embedder or _faiss_rules_index is None: return False, "Rule system not ready."
 
@@ -264,11 +253,11 @@ def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
    return []
 
 def remove_rule_entry(rule_text_to_delete: str) -> bool:
+    global _rules_items_list, _faiss_rules_index
    if not _initialized: initialize_memory_system()
    rule_text_to_delete = rule_text_to_delete.strip()
    if rule_text_to_delete not in _rules_items_list: return False
    try:
-        # Rebuild index and list without the deleted rule
        new_list = [r for r in _rules_items_list if r != rule_text_to_delete]
        _faiss_rules_index, _rules_items_list = _build_faiss_index(new_list, lambda r: r)
 
@@ -292,7 +281,7 @@ def get_all_memories_cached() -> list[dict]:
        try:
            valid_mems.append(json.loads(m_str))
        except json.JSONDecodeError:
-            continue # Skip corrupted data for UI display
+            continue
    return valid_mems
 
 def clear_all_memory_data_backend() -> bool:
@@ -324,8 +313,6 @@ def clear_all_rules_data_backend() -> bool:
        logger.error(f"Error clearing rules data: {e}"); return False
 
 def save_faiss_indices_to_disk():
-    # This function is primarily for the RAM backend, which is not the focus here.
-    # It's kept for compatibility.
    if not _initialized or not faiss: return
    faiss_dir = "app_data/faiss_indices"
    os.makedirs(faiss_dir, exist_ok=True)
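The core of this commit is the `_build_faiss_index` contract: every exit path now returns a two-element `(index, valid_items)` result (`None, []` only when the embedder is missing), instead of sometimes returning a bare index. Callers such as `initialize_memory_system` and `remove_rule_entry` can therefore unpack the result unconditionally and keep each items list in lockstep with its FAISS index. A minimal sketch of that contract, with a stand-in `embed_fn` in place of the module's SentenceTransformer embedder; the stub and sample data below are illustrative, not from the repo:

import faiss
import numpy as np

_dimension = 384  # matches the module's default embedding dimension

def build_index_sketch(items_list, text_extraction_fn, embed_fn):
    # Illustrative reduction of the patched function: every path returns
    # an (index, valid_items) pair, so the list always matches the index.
    index = faiss.IndexFlatL2(_dimension)
    if not items_list:
        return index, []
    texts, valid_items = [], []
    for item in items_list:
        try:
            texts.append(text_extraction_fn(item))
            valid_items.append(item)
        except Exception:
            continue  # drop items the extractor cannot parse
    if not texts:
        return index, []
    embeddings = np.asarray(embed_fn(texts), dtype="float32")
    index.add(embeddings)
    return index, valid_items

# Callers can now always unpack, and the list length matches the index:
embed_fn = lambda texts: np.random.rand(len(texts), _dimension)  # stand-in embedder
idx, items = build_index_sketch(["a", "b"], lambda s: s, embed_fn)
assert idx.ntotal == len(items)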
 
 
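The other fix is how the module keeps its global lists in sync with the indices. The old `_verify_and_rebuild_if_needed` mutated whatever list it was handed via `clear()`/`extend()`; the new version declares `global` and rebinds the module-level name instead, using an `is` identity check to work out which global the caller passed in. The `global` statement is what makes this work: without it, the assignment would only create a function-local name. A toy demonstration of the pattern (variable names mirror the patch, the data is made up):

_memory_items_list = ["m1", "m2", "m3"]
_rules_items_list = ["r1"]

def rebind_validated(items_list, valid_items):
    # Without this `global` declaration, the assignments below would bind
    # local names and silently leave the module state unchanged.
    global _memory_items_list, _rules_items_list
    # `is` compares object identity, so it identifies which module-level
    # list the caller handed in, even if another list compares equal.
    if items_list is _memory_items_list:
        _memory_items_list = valid_items
    elif items_list is _rules_items_list:
        _rules_items_list = valid_items

rebind_validated(_memory_items_list, ["m1", "m3"])
print(_memory_items_list)  # ['m1', 'm3'] — the global now points at the new list

The same reasoning explains the `global` lines added to `add_memory_entry`, `add_rule_entry`, and `remove_rule_entry`: each rebinds `_faiss_*_index` or `_*_items_list` rather than mutating in place, e.g. `_faiss_rules_index, _rules_items_list = _build_faiss_index(new_list, lambda r: r)`.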
 
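Finally, the configuration block loses its comments but keeps its behavior: the backend is still chosen from the `STORAGE_BACKEND` environment variable, and the removed inline comment listed `HF_DATASET`, `RAM`, and `SQLITE` as the recognized values, with `HF_DATASET` the default. A usage sketch under those assumptions; the exact behavior of each backend is defined elsewhere in the module:

import os

# Select the backend before the module is imported, since os.getenv
# runs at module import time. Values per the removed inline comment:
# "HF_DATASET" (default), "RAM", "SQLITE".
os.environ["STORAGE_BACKEND"] = "SQLITE"
os.environ["SQLITE_DB_PATH"] = "app_data/ai_memory.db"  # module default

import memory_logic
memory_logic.initialize_memory_system()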