broadfield-dev committed
Commit 60c007c · verified · 1 parent: f07b8e0

Update memory_logic.py

Files changed (1)
  1. memory_logic.py +117 -165
memory_logic.py CHANGED
@@ -32,181 +32,162 @@ except ImportError:
 logger = logging.getLogger(__name__)
 # Suppress verbose logs from dependencies
 for lib_name in ["sentence_transformers", "faiss", "datasets", "huggingface_hub"]:
-     if logging.getLogger(lib_name): # Check if logger exists
-         logging.getLogger(lib_name).setLevel(logging.WARNING)
+     if logging.getLogger(lib_name): logging.getLogger(lib_name).setLevel(logging.WARNING)


 # --- Configuration (Read directly from environment variables) ---
 STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper() #HF_DATASET, RAM, SQLITE
- SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db") # Changed default path
+ SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db")
 HF_TOKEN = os.getenv("HF_TOKEN")
- HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain") # Example
- HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules") # Example
+ HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain")
+ HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules")

 # --- Globals for RAG within this module ---
 _embedder = None
- _dimension = 384 # Default, will be set by embedder
+ _dimension = 384
 _faiss_memory_index = None
- _memory_items_list = [] # Stores JSON strings of memory objects for RAM, or loaded from DB/HF
+ _memory_items_list = []
 _faiss_rules_index = None
- _rules_items_list = [] # Stores rule text strings
+ _rules_items_list = []

 _initialized = False
 _init_lock = threading.Lock()

- # --- Helper: SQLite Connection ---
 def _get_sqlite_connection():
-     if not sqlite3:
-         raise ImportError("sqlite3 module is required for SQLite backend but not found.")
+     if not sqlite3: raise ImportError("sqlite3 module is required for SQLite backend.")
     db_dir = os.path.dirname(SQLITE_DB_PATH)
-     if db_dir and not os.path.exists(db_dir):
-         os.makedirs(db_dir, exist_ok=True)
+     if db_dir: os.makedirs(db_dir, exist_ok=True)
     return sqlite3.connect(SQLITE_DB_PATH, timeout=10)

 def _init_sqlite_tables():
-     if STORAGE_BACKEND != "SQLITE" or not sqlite3:
-         return
+     if STORAGE_BACKEND != "SQLITE" or not sqlite3: return
     try:
         with _get_sqlite_connection() as conn:
             cursor = conn.cursor()
-             cursor.execute("""
-                 CREATE TABLE IF NOT EXISTS memories (
-                     id INTEGER PRIMARY KEY AUTOINCREMENT,
-                     memory_json TEXT NOT NULL,
-                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-                 )
-             """)
-             cursor.execute("""
-                 CREATE TABLE IF NOT EXISTS rules (
-                     id INTEGER PRIMARY KEY AUTOINCREMENT,
-                     rule_text TEXT NOT NULL UNIQUE,
-                     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-                 )
-             """)
+             cursor.execute("CREATE TABLE IF NOT EXISTS memories (id INTEGER PRIMARY KEY, memory_json TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)")
+             cursor.execute("CREATE TABLE IF NOT EXISTS rules (id INTEGER PRIMARY KEY, rule_text TEXT NOT NULL UNIQUE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)")
             conn.commit()
-             logger.info("SQLite tables for memories and rules checked/created.")
+             logger.info("SQLite tables checked/created.")
     except Exception as e:
         logger.error(f"SQLite table initialization error: {e}", exc_info=True)

 def _build_faiss_index(items_list, text_extraction_fn):
-     """Builds a FAISS index from a list of items."""
     if not _embedder:
         logger.error("Cannot build FAISS index: Embedder not available.")
-         return None
+         return None, []

     index = faiss.IndexFlatL2(_dimension)
-     if not items_list:
-         return index
-
-     logger.info(f"Building FAISS index for {len(items_list)} items...")
-     texts_to_embed = [text_extraction_fn(item) for item in items_list]
+     if not items_list: return index, []
+
+     texts_to_embed, valid_items = [], []
+     for item in items_list:
+         try:
+             texts_to_embed.append(text_extraction_fn(item))
+             valid_items.append(item)
+         except (json.JSONDecodeError, TypeError) as e:
+             logger.warning(f"Skipping item during FAISS indexing due to processing error: {e}. Item: '{str(item)[:100]}...'")
+
+     if not texts_to_embed:
+         logger.warning("No valid items to embed for FAISS index after filtering.")
+         return index, []

     try:
         embeddings = _embedder.encode(texts_to_embed, convert_to_tensor=False, show_progress_bar=False)
         embeddings_np = np.array(embeddings, dtype=np.float32)
-         if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(items_list):
+         if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(texts_to_embed):
             index.add(embeddings_np)
-             logger.info(f"FAISS index built successfully with {index.ntotal} items.")
+             logger.info(f"FAISS index built with {index.ntotal} / {len(items_list)} items.")
+             # The filtered valid_items list is returned so callers can keep the cache in sync with the index
+             return index, valid_items
         else:
-             logger.error(f"FAISS build failed: Embeddings shape error. Expected ({len(items_list)}, {_dimension}), Got {getattr(embeddings_np, 'shape', 'N/A')}")
-             return faiss.IndexFlatL2(_dimension) # Return empty index on failure
+             logger.error("FAISS build failed: Embeddings shape error.")
+             return faiss.IndexFlatL2(_dimension), []
     except Exception as e:
         logger.error(f"Error building FAISS index: {e}", exc_info=True)
-         return faiss.IndexFlatL2(_dimension) # Return empty index on failure
-
-     return index
+         return faiss.IndexFlatL2(_dimension), []

- # --- Initialization ---
 def initialize_memory_system():
     global _initialized, _embedder, _dimension, _faiss_memory_index, _memory_items_list, _faiss_rules_index, _rules_items_list

     with _init_lock:
-         if _initialized:
-             return
+         if _initialized: return

         logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
         init_start_time = time.time()

-         if not SentenceTransformer or not faiss or not np:
+         if not all([SentenceTransformer, faiss, np]):
             logger.error("Core RAG libraries not available. Cannot initialize semantic memory.")
             return

         if not _embedder:
             try:
-                 logger.info("Loading SentenceTransformer model (all-MiniLM-L6-v2)...")
+                 logger.info("Loading SentenceTransformer model...")
                 _embedder = SentenceTransformer('all-MiniLM-L6-v2', cache_folder="./sentence_transformer_cache")
                 _dimension = _embedder.get_sentence_embedding_dimension() or 384
             except Exception as e:
-                 logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
-                 return
+                 logger.critical(f"FATAL: Could not load SentenceTransformer model. Semantic search disabled. Error: {e}", exc_info=True)
+                 return # Stop initialization if embedder fails

         if STORAGE_BACKEND == "SQLITE": _init_sqlite_tables()

-         # Load Memories from persistent storage
-         temp_memories_json = []
+         # Load raw data
+         raw_mems = []
         if STORAGE_BACKEND == "SQLITE":
-             try: temp_memories_json = [row[0] for row in _get_sqlite_connection().execute("SELECT memory_json FROM memories")]
+             try: raw_mems = [row[0] for row in _get_sqlite_connection().execute("SELECT memory_json FROM memories")]
             except Exception as e: logger.error(f"Error loading memories from SQLite: {e}")
         elif STORAGE_BACKEND == "HF_DATASET":
             try:
-                 logger.info(f"Loading memories from HF Dataset: {HF_MEMORY_DATASET_REPO}")
                 dataset = load_dataset(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True)
                 if "train" in dataset and "memory_json" in dataset["train"].column_names:
-                     temp_memories_json = [m for m in dataset["train"]["memory_json"] if isinstance(m, str) and m.strip()]
-                     logger.info(f"Loaded {len(temp_memories_json)} valid memories from HF Dataset.")
-                 else: logger.warning(f"HF Dataset {HF_MEMORY_DATASET_REPO} has no 'train' split or 'memory_json' column.")
+                     raw_mems = [m for m in dataset["train"]["memory_json"] if isinstance(m, str) and m.strip()]
             except Exception as e: logger.error(f"Error loading memories from HF Dataset: {e}", exc_info=True)

-         _memory_items_list = temp_memories_json
-
-         # Build Memory FAISS Index
-         _faiss_memory_index = _build_faiss_index(
-             _memory_items_list,
-             lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}\nTakeaway: {json.loads(m).get('metrics', {}).get('takeaway', 'N/A')}"
-         )
-
-         # Load Rules from persistent storage
-         temp_rules_text = []
+         # Build Memory Index and get validated list
+         mem_index, valid_mems = _build_faiss_index(raw_mems, lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}")
+         _faiss_memory_index = mem_index
+         _memory_items_list = valid_mems # Use the validated list that matches the index
+         logger.info(f"Loaded and indexed {len(_memory_items_list)} memories.")
+
+         # Load Rules
+         raw_rules = []
         if STORAGE_BACKEND == "SQLITE":
-             try: temp_rules_text = [row[0] for row in _get_sqlite_connection().execute("SELECT rule_text FROM rules")]
+             try: raw_rules = [row[0] for row in _get_sqlite_connection().execute("SELECT rule_text FROM rules")]
             except Exception as e: logger.error(f"Error loading rules from SQLite: {e}")
         elif STORAGE_BACKEND == "HF_DATASET":
             try:
-                 logger.info(f"Loading rules from HF Dataset: {HF_RULES_DATASET_REPO}")
                 dataset = load_dataset(HF_RULES_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True)
                 if "train" in dataset and "rule_text" in dataset["train"].column_names:
-                     temp_rules_text = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
-                     logger.info(f"Loaded {len(temp_rules_text)} valid rules from HF Dataset.")
-                 else: logger.warning(f"HF Dataset {HF_RULES_DATASET_REPO} has no 'train' split or 'rule_text' column.")
+                     raw_rules = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
             except Exception as e: logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)

-         _rules_items_list = sorted(list(set(temp_rules_text)))
-
-         # Build Rules FAISS Index
-         _faiss_rules_index = _build_faiss_index(_rules_items_list, lambda r: r)
-
-         _initialized = True
-         logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
-
+         # Build Rules Index and get validated list
+         rule_index, valid_rules = _build_faiss_index(sorted(list(set(raw_rules))), lambda r: r)
+         _faiss_rules_index = rule_index
+         _rules_items_list = valid_rules # Use the validated list that matches the index
+         logger.info(f"Loaded and indexed {len(_rules_items_list)} rules.")
+
+         # Only mark as initialized if the core components are ready
+         if _embedder and _faiss_memory_index is not None and _faiss_rules_index is not None:
+             _initialized = True
+             logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
+         else:
+             logger.error("Memory system initialization failed. Core components are not ready.")

 def _verify_and_rebuild_if_needed(index, items_list, text_extraction_fn):
-     """Self-healing function to ensure FAISS index is synced with the item list."""
     if not index or index.ntotal != len(items_list):
-         logger.warning(
-             f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). "
-             "Rebuilding index from in-memory cache."
-         )
-         return _build_faiss_index(items_list, text_extraction_fn)
+         logger.warning(f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). Rebuilding...")
+         new_index, valid_items = _build_faiss_index(items_list, text_extraction_fn)
+         # Mutate the caller's list in place so the global cache matches the rebuilt index
+         if isinstance(items_list, list) and isinstance(valid_items, list):
+             items_list.clear()
+             items_list.extend(valid_items)
+         return new_index
     return index

-
- # --- Memory Operations (Semantic) ---
 def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
-     global _memory_items_list, _faiss_memory_index
     if not _initialized: initialize_memory_system()
-     if not _embedder or not _faiss_memory_index:
-         return False, "Memory system not ready for adding entries."
+     if not _embedder or _faiss_memory_index is None: return False, "Memory system not ready."

     memory_obj = {"user_input": user_input, "metrics": metrics, "bot_response": bot_response, "timestamp": datetime.utcnow().isoformat()}
     memory_json_str = json.dumps(memory_obj)
@@ -214,168 +195,139 @@ def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple

     try:
         embedding = _embedder.encode([text_to_embed], convert_to_tensor=False)
-         embedding_np = np.array(embedding, dtype=np.float32)
-
-         _faiss_memory_index.add(embedding_np)
+         _faiss_memory_index.add(np.array(embedding, dtype=np.float32))
         _memory_items_list.append(memory_json_str)

         if STORAGE_BACKEND == "SQLITE":
-             with _get_sqlite_connection() as conn:
-                 conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json_str,)); conn.commit()
+             with _get_sqlite_connection() as conn: conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json_str,)); conn.commit()
         elif STORAGE_BACKEND == "HF_DATASET":
             Dataset.from_dict({"memory_json": list(_memory_items_list)}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)

-         logger.info(f"Added memory. Cache size: {len(_memory_items_list)}, FAISS size: {_faiss_memory_index.ntotal}")
-         return True, "Memory added successfully."
+         return True, "Memory added."
     except Exception as e:
         logger.error(f"Error adding memory entry: {e}", exc_info=True)
-         return False, f"Error adding memory: {e}"
+         return False, f"Error: {e}"

 def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
     global _faiss_memory_index
     if not _initialized: initialize_memory_system()
-
-     # Self-healing: Verify index is synced with cache, rebuild if not.
-     _faiss_memory_index = _verify_and_rebuild_if_needed(
-         _faiss_memory_index, _memory_items_list,
-         lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}\nTakeaway: {json.loads(m).get('metrics', {}).get('takeaway', 'N/A')}"
-     )
-
-     if not _faiss_memory_index or _faiss_memory_index.ntotal == 0:
-         logger.debug("Cannot retrieve memories: index is empty.")
-         return []
+     if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []
+
+     _faiss_memory_index = _verify_and_rebuild_if_needed(_faiss_memory_index, _memory_items_list, lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}")
+     if not _faiss_memory_index or _faiss_memory_index.ntotal == 0: return []

     try:
         query_embedding = _embedder.encode([query], convert_to_tensor=False)
-         query_embedding_np = np.array(query_embedding, dtype=np.float32)
-         distances, indices = _faiss_memory_index.search(query_embedding_np, min(k, _faiss_memory_index.ntotal))
-
-         results = [json.loads(_memory_items_list[i]) for i in indices[0] if 0 <= i < len(_memory_items_list)]
-         logger.info(f"Retrieved {len(results)} memories for query: '{query[:50]}...'")
-         return results
+         distances, indices = _faiss_memory_index.search(np.array(query_embedding, dtype=np.float32), min(k, _faiss_memory_index.ntotal))
+         return [json.loads(_memory_items_list[i]) for i in indices[0] if 0 <= i < len(_memory_items_list)]
     except Exception as e:
-         logger.error(f"Error retrieving memories semantically: {e}", exc_info=True)
+         logger.error(f"Error retrieving memories: {e}", exc_info=True)
         return []

-
- # --- Rule (Insight) Operations (Semantic) ---
 def add_rule_entry(rule_text: str) -> tuple[bool, str]:
-     global _rules_items_list, _faiss_rules_index
     if not _initialized: initialize_memory_system()
+     if not _embedder or _faiss_rules_index is None: return False, "Rule system not ready."

     rule_text = rule_text.strip()
-     if not rule_text or "duplicate" == rule_text or rule_text in _rules_items_list:
-         return False, "duplicate or invalid"
-     if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", rule_text, re.I):
-         return False, "Invalid rule format."
+     if not rule_text or rule_text in _rules_items_list: return False, "duplicate or invalid"
+     if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", rule_text, re.I): return False, "Invalid format."

     try:
         embedding = _embedder.encode([rule_text], convert_to_tensor=False)
-         embedding_np = np.array(embedding, dtype=np.float32)
-         _faiss_rules_index.add(embedding_np)
+         _faiss_rules_index.add(np.array(embedding, dtype=np.float32))
         _rules_items_list.append(rule_text)
         _rules_items_list.sort()

         if STORAGE_BACKEND == "SQLITE":
-             with _get_sqlite_connection() as conn:
-                 conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,)); conn.commit()
+             with _get_sqlite_connection() as conn: conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,)); conn.commit()
         elif STORAGE_BACKEND == "HF_DATASET":
             Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
-         return True, "Rule added successfully."
+         return True, "Rule added."
     except Exception as e:
-         logger.error(f"Error adding rule entry: {e}", exc_info=True)
-         return False, f"Error adding rule: {e}"
+         logger.error(f"Error adding rule: {e}", exc_info=True)
+         return False, f"Error: {e}"

 def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
     global _faiss_rules_index
     if not _initialized: initialize_memory_system()
+     if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []

     _faiss_rules_index = _verify_and_rebuild_if_needed(_faiss_rules_index, _rules_items_list, lambda r: r)
     if not _faiss_rules_index or _faiss_rules_index.ntotal == 0: return []

     try:
         query_embedding = _embedder.encode([query], convert_to_tensor=False)
-         query_embedding_np = np.array(query_embedding, dtype=np.float32)
-         distances, indices = _faiss_rules_index.search(query_embedding_np, min(k, _faiss_rules_index.ntotal))
+         distances, indices = _faiss_rules_index.search(np.array(query_embedding, dtype=np.float32), min(k, _faiss_rules_index.ntotal))
         return [_rules_items_list[i] for i in indices[0] if 0 <= i < len(_rules_items_list)]
     except Exception as e:
-         logger.error(f"Error retrieving rules semantically: {e}", exc_info=True)
+         logger.error(f"Error retrieving rules: {e}", exc_info=True)
         return []

 def remove_rule_entry(rule_text_to_delete: str) -> bool:
     global _rules_items_list, _faiss_rules_index
     if not _initialized: initialize_memory_system()
     rule_text_to_delete = rule_text_to_delete.strip()
     if rule_text_to_delete not in _rules_items_list: return False
     try:
-         _rules_items_list.remove(rule_text_to_delete)
-         _faiss_rules_index = _build_faiss_index(_rules_items_list, lambda r: r)
+         # Rebuild index and list without the deleted rule
+         new_list = [r for r in _rules_items_list if r != rule_text_to_delete]
+         _faiss_rules_index, _rules_items_list = _build_faiss_index(new_list, lambda r: r)

         if STORAGE_BACKEND == "SQLITE":
-             with _get_sqlite_connection() as conn:
-                 conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,)); conn.commit()
+             with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,)); conn.commit()
         elif STORAGE_BACKEND == "HF_DATASET":
             Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
         return True
     except Exception as e:
-         logger.error(f"Error removing rule entry: {e}", exc_info=True)
+         logger.error(f"Error removing rule: {e}", exc_info=True)
         return False

- # --- Utility functions to get all data (for UI display, etc.) ---
 def get_all_rules_cached() -> list[str]:
     if not _initialized: initialize_memory_system()
     return list(_rules_items_list)

 def get_all_memories_cached() -> list[dict]:
     if not _initialized: initialize_memory_system()
-     return [json.loads(m) for m in _memory_items_list if m]
+     valid_mems = []
+     for m_str in _memory_items_list:
+         try:
+             valid_mems.append(json.loads(m_str))
+         except json.JSONDecodeError:
+             continue # Skip corrupted data for UI display
+     return valid_mems

 def clear_all_memory_data_backend() -> bool:
     global _memory_items_list, _faiss_memory_index
     if not _initialized: initialize_memory_system()
-     _memory_items_list = []
+     _memory_items_list.clear()
     if _faiss_memory_index: _faiss_memory_index.reset()
     try:
         if STORAGE_BACKEND == "SQLITE":
             with _get_sqlite_connection() as conn: conn.execute("DELETE FROM memories"); conn.commit()
         elif STORAGE_BACKEND == "HF_DATASET":
             Dataset.from_dict({"memory_json": []}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
-         logger.info("All memories cleared.")
         return True
     except Exception as e:
-         logger.error(f"Error clearing all memory data: {e}"); return False
+         logger.error(f"Error clearing memory data: {e}"); return False

 def clear_all_rules_data_backend() -> bool:
     global _rules_items_list, _faiss_rules_index
     if not _initialized: initialize_memory_system()
-     _rules_items_list = []
+     _rules_items_list.clear()
     if _faiss_rules_index: _faiss_rules_index.reset()
     try:
         if STORAGE_BACKEND == "SQLITE":
             with _get_sqlite_connection() as conn: conn.execute("DELETE FROM rules"); conn.commit()
         elif STORAGE_BACKEND == "HF_DATASET":
             Dataset.from_dict({"rule_text": []}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
-         logger.info("All rules cleared.")
         return True
     except Exception as e:
-         logger.error(f"Error clearing all rules data: {e}"); return False
-
- FAISS_MEMORY_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "memory_index.faiss")
- FAISS_RULES_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "rules_index.faiss")
+         logger.error(f"Error clearing rules data: {e}"); return False

 def save_faiss_indices_to_disk():
+     # This function is primarily for the RAM backend; kept for compatibility.
     if not _initialized or not faiss: return
-     faiss_dir = os.path.dirname(FAISS_MEMORY_PATH)
-     if not os.path.exists(faiss_dir): os.makedirs(faiss_dir, exist_ok=True)
-     if _faiss_memory_index and _faiss_memory_index.ntotal > 0:
-         faiss.write_index(_faiss_memory_index, FAISS_MEMORY_PATH)
-     if _faiss_rules_index and _faiss_rules_index.ntotal > 0:
-         faiss.write_index(_faiss_rules_index, FAISS_RULES_PATH)
-
- def load_faiss_indices_from_disk():
-     global _faiss_memory_index, _faiss_rules_index
-     if not _initialized or not faiss: return
-     if os.path.exists(FAISS_MEMORY_PATH):
-         _faiss_memory_index = faiss.read_index(FAISS_MEMORY_PATH)
-     if os.path.exists(FAISS_RULES_PATH):
-         _faiss_rules_index = faiss.read_index(FAISS_RULES_PATH)
+     faiss_dir = "app_data/faiss_indices"
+     os.makedirs(faiss_dir, exist_ok=True)
+     if _faiss_memory_index: faiss.write_index(_faiss_memory_index, os.path.join(faiss_dir, "memory_index.faiss"))
+     if _faiss_rules_index: faiss.write_index(_faiss_rules_index, os.path.join(faiss_dir, "rules_index.faiss"))
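
A minimal usage sketch of the updated module, assuming it is importable as memory_logic and that sentence-transformers, faiss, numpy, and datasets are installed. The RAM backend is selected so nothing is written to SQLite or the Hub, and the example strings are illustrative only:

import os

# STORAGE_BACKEND is read at import time, so set it before importing the module.
os.environ["STORAGE_BACKEND"] = "RAM"

import memory_logic

# One-time setup: loads all-MiniLM-L6-v2 and builds both FAISS indices
# (empty for a fresh RAM backend).
memory_logic.initialize_memory_system()

# Embed and cache one interaction.
ok, msg = memory_logic.add_memory_entry(
    user_input="How do I reset the device?",
    metrics={"takeaway": "User needs reset steps"},
    bot_response="Hold the power button for ten seconds.",
)

# Rules must match the bracketed format enforced in add_rule_entry.
memory_logic.add_rule_entry("[CORE_RULE|1.0] Answer reset questions step by step.")

# Semantic lookups return parsed memory dicts and rule strings.
print(memory_logic.retrieve_memories_semantic("resetting hardware", k=1))
print(memory_logic.retrieve_rules_semantic("how to reset", k=1))

With the tuple-return convention above, _build_faiss_index always yields an (index, valid_items) pair, so callers can keep the in-memory caches aligned with what was actually embedded.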