broadfield-dev committed
Commit 38b39e1 · verified · 1 Parent(s): caeccf9

Update memory_logic.py

Files changed (1)
  1. memory_logic.py +251 -77
memory_logic.py CHANGED
@@ -1,4 +1,3 @@
1
- # memory_logic.py
2
  import os
3
  import json
4
  import time
@@ -7,7 +6,6 @@ import logging
7
  import re
8
  import threading
9
 
10
- # Conditionally import heavy dependencies
11
  try:
12
  from sentence_transformers import SentenceTransformer
13
  import faiss
@@ -30,38 +28,34 @@ except ImportError:
30
 
31
 
32
  logger = logging.getLogger(__name__)
33
- # Suppress verbose logs from dependencies
34
  for lib_name in ["sentence_transformers", "faiss", "datasets", "huggingface_hub"]:
35
- if logging.getLogger(lib_name): # Check if logger exists
36
  logging.getLogger(lib_name).setLevel(logging.WARNING)
37
 
38
 
39
- # --- Configuration (Read directly from environment variables) ---
40
- STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper() #HF_DATASET, RAM, SQLITE
41
- SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db") # Changed default path
42
  HF_TOKEN = os.getenv("HF_TOKEN")
43
- HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain") # Example
44
- HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules") # Example
45
 
46
- # --- Globals for RAG within this module ---
47
  _embedder = None
48
- _dimension = 384 # Default, will be set by embedder
49
  _faiss_memory_index = None
50
- _memory_items_list = [] # Stores JSON strings of memory objects for RAM, or loaded from DB/HF
51
  _faiss_rules_index = None
52
- _rules_items_list = [] # Stores rule text strings
53
 
54
  _initialized = False
55
  _init_lock = threading.Lock()
56
 
57
- # --- Helper: SQLite Connection ---
58
  def _get_sqlite_connection():
59
  if not sqlite3:
60
  raise ImportError("sqlite3 module is required for SQLite backend but not found.")
61
  db_dir = os.path.dirname(SQLITE_DB_PATH)
62
  if db_dir and not os.path.exists(db_dir):
63
  os.makedirs(db_dir, exist_ok=True)
64
- return sqlite3.connect(SQLITE_DB_PATH, timeout=10) # Added timeout
65
 
66
  def _init_sqlite_tables():
67
  if STORAGE_BACKEND != "SQLITE" or not sqlite3:
@@ -69,22 +63,17 @@ def _init_sqlite_tables():
69
  try:
70
  with _get_sqlite_connection() as conn:
71
  cursor = conn.cursor()
72
- # Stores JSON string of the memory object
73
  cursor.execute("""
74
  CREATE TABLE IF NOT EXISTS memories (
75
  id INTEGER PRIMARY KEY AUTOINCREMENT,
76
  memory_json TEXT NOT NULL,
77
- # Optionally add embedding here if not using separate FAISS index
78
- # embedding BLOB,
79
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
80
  )
81
  """)
82
- # Stores the rule text directly
83
  cursor.execute("""
84
  CREATE TABLE IF NOT EXISTS rules (
85
  id INTEGER PRIMARY KEY AUTOINCREMENT,
86
  rule_text TEXT NOT NULL UNIQUE,
87
- # embedding BLOB,
88
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
89
  )
90
  """)
@@ -93,7 +82,6 @@ def _init_sqlite_tables():
93
  except Exception as e:
94
  logger.error(f"SQLite table initialization error: {e}", exc_info=True)
95
 
96
- # --- Initialization ---
97
  def initialize_memory_system():
98
  global _initialized, _embedder, _dimension, _faiss_memory_index, _memory_items_list, _faiss_rules_index, _rules_items_list
99
 
@@ -105,10 +93,9 @@ def initialize_memory_system():
105
  logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
106
  init_start_time = time.time()
107
 
108
- # 1. Load Sentence Transformer Model (always needed for semantic operations)
109
  if not SentenceTransformer or not faiss or not np:
110
  logger.error("Core RAG libraries (SentenceTransformers, FAISS, NumPy) not available. Cannot initialize semantic memory.")
111
- _initialized = False # Mark as not properly initialized
112
  return
113
 
114
  if not _embedder:
@@ -120,17 +107,15 @@ def initialize_memory_system():
120
  except Exception as e:
121
  logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
122
  _initialized = False
123
- return # Cannot proceed without embedder
124
 
125
- # 2. Initialize SQLite if used
126
  if STORAGE_BACKEND == "SQLITE":
127
  _init_sqlite_tables()
128
 
129
- # 3. Load Memories
130
  logger.info("Loading memories...")
131
  temp_memories_json = []
132
  if STORAGE_BACKEND == "RAM":
133
- _memory_items_list = [] # Start fresh for RAM backend
134
  elif STORAGE_BACKEND == "SQLITE" and sqlite3:
135
  try:
136
  with _get_sqlite_connection() as conn:
@@ -139,8 +124,8 @@ def initialize_memory_system():
139
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset and load_dataset:
140
  try:
141
  logger.info(f"Attempting to load memories from HF Dataset: {HF_MEMORY_DATASET_REPO}")
142
- dataset = load_dataset(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True) # Add download_mode if needed
143
- if "train" in dataset and "memory_json" in dataset["train"].column_names: # Assuming 'memory_json' column
144
  temp_memories_json = [m_json for m_json in dataset["train"]["memory_json"] if isinstance(m_json, str)]
145
  else: logger.warning(f"HF Dataset {HF_MEMORY_DATASET_REPO} for memories not found or 'memory_json' column missing.")
146
  except Exception as e: logger.error(f"Error loading memories from HF Dataset ({HF_MEMORY_DATASET_REPO}): {e}")
@@ -148,16 +133,13 @@ def initialize_memory_system():
148
  _memory_items_list = temp_memories_json
149
  logger.info(f"Loaded {len(_memory_items_list)} memory items from {STORAGE_BACKEND}.")
150
 
151
- # 4. Build/Load FAISS Memory Index
152
  _faiss_memory_index = faiss.IndexFlatL2(_dimension)
153
  if _memory_items_list:
154
  logger.info(f"Building FAISS index for {len(_memory_items_list)} memories...")
155
- # Extract text to embed from memory JSON objects
156
  texts_to_embed_mem = []
157
  for mem_json_str in _memory_items_list:
158
  try:
159
  mem_obj = json.loads(mem_json_str)
160
- # Consistent embedding strategy: user input + bot response + takeaway
161
  text = f"User: {mem_obj.get('user_input','')}\nAI: {mem_obj.get('bot_response','')}\nTakeaway: {mem_obj.get('metrics',{}).get('takeaway','N/A')}"
162
  texts_to_embed_mem.append(text)
163
  except json.JSONDecodeError:
@@ -165,7 +147,7 @@ def initialize_memory_system():
165
 
166
  if texts_to_embed_mem:
167
  try:
168
- embeddings = _embedder.encode(texts_to_embed_mem, convert_to_tensor=False, show_progress_bar=False) # convert_to_numpy=True
169
  embeddings_np = np.array(embeddings, dtype=np.float32)
170
  if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(texts_to_embed_mem) and embeddings_np.shape[1] == _dimension:
171
  _faiss_memory_index.add(embeddings_np)
@@ -173,8 +155,6 @@ def initialize_memory_system():
173
  except Exception as e_faiss_mem: logger.error(f"Error building FAISS memory index: {e_faiss_mem}")
174
  logger.info(f"FAISS memory index built. Total items: {_faiss_memory_index.ntotal if _faiss_memory_index else 'N/A'}")
175
 
176
-
177
- # 5. Load Rules
178
  logger.info("Loading rules...")
179
  temp_rules_text = []
180
  if STORAGE_BACKEND == "RAM":
@@ -193,14 +173,13 @@ def initialize_memory_system():
193
  else: logger.warning(f"HF Dataset {HF_RULES_DATASET_REPO} for rules not found or 'rule_text' column missing.")
194
  except Exception as e: logger.error(f"Error loading rules from HF Dataset ({HF_RULES_DATASET_REPO}): {e}")
195
 
196
- _rules_items_list = sorted(list(set(temp_rules_text))) # Ensure unique and sorted
197
  logger.info(f"Loaded {len(_rules_items_list)} rule items from {STORAGE_BACKEND}.")
198
 
199
- # 6. Build/Load FAISS Rules Index
200
  _faiss_rules_index = faiss.IndexFlatL2(_dimension)
201
  if _rules_items_list:
202
  logger.info(f"Building FAISS index for {len(_rules_items_list)} rules...")
203
- if _rules_items_list: # Check again in case it became empty after filtering
204
  try:
205
  embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False, show_progress_bar=False)
206
  embeddings_np = np.array(embeddings, dtype=np.float32)
@@ -214,9 +193,7 @@ def initialize_memory_system():
214
  logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
215
 
216
 
217
- # --- Memory Operations (Semantic) ---
218
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
219
- """Adds a memory entry to the configured backend and FAISS index."""
220
  global _memory_items_list, _faiss_memory_index
221
  if not _initialized: initialize_memory_system()
222
  if not _embedder or not _faiss_memory_index:
@@ -240,31 +217,25 @@ def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple
240
  logger.error(f"Memory embedding shape error: {embedding_np.shape}. Expected (1, {_dimension})")
241
  return False, "Embedding shape error."
242
 
243
- # Add to FAISS
244
  _faiss_memory_index.add(embedding_np)
245
 
246
- # Add to in-memory list
247
  _memory_items_list.append(memory_json_str)
248
 
249
- # Add to persistent storage
250
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
251
  with _get_sqlite_connection() as conn:
252
  conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json_str,))
253
  conn.commit()
254
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
255
- # This can be slow, consider batching or async push
256
  logger.info(f"Pushing {len(_memory_items_list)} memories to HF Hub: {HF_MEMORY_DATASET_REPO}")
257
- Dataset.from_dict({"memory_json": list(_memory_items_list)}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True) # Ensure 'private' as needed
258
 
259
  logger.info(f"Added memory. RAM: {len(_memory_items_list)}, FAISS: {_faiss_memory_index.ntotal}")
260
  return True, "Memory added successfully."
261
  except Exception as e:
262
  logger.error(f"Error adding memory entry: {e}", exc_info=True)
263
- # TODO: Potential rollback logic if FAISS add succeeded but backend failed (complex)
264
  return False, f"Error adding memory: {e}"
265
 
266
  def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
267
- """Retrieves k most relevant memories using semantic search."""
268
  if not _initialized: initialize_memory_system()
269
  if not _embedder or not _faiss_memory_index or _faiss_memory_index.ntotal == 0:
270
  logger.debug("Cannot retrieve memories: Embedder, FAISS index not ready, or index is empty.")
@@ -297,9 +268,7 @@ def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
297
  return []
298
 
299
 
300
- # --- Rule (Insight) Operations (Semantic) ---
301
  def add_rule_entry(rule_text: str) -> tuple[bool, str]:
302
- """Adds a rule if valid and not a duplicate. Updates backend and FAISS."""
303
  global _rules_items_list, _faiss_rules_index
304
  if not _initialized: initialize_memory_system()
305
  if not _embedder or not _faiss_rules_index:
@@ -335,15 +304,9 @@ def add_rule_entry(rule_text: str) -> tuple[bool, str]:
335
  return True, "Rule added successfully."
336
  except Exception as e:
337
  logger.error(f"Error adding rule entry: {e}", exc_info=True)
338
- # Basic rollback if FAISS add succeeded
339
- if rule_text in _rules_items_list and _faiss_rules_index.ntotal > 0: # Crude check
340
- # A full rollback would involve rebuilding FAISS index from _rules_items_list before append.
341
- # For simplicity, this is omitted here. State could be inconsistent on error.
342
- pass
343
  return False, f"Error adding rule: {e}"
344
 
345
  def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
346
- """Retrieves k most relevant rules using semantic search."""
347
  if not _initialized: initialize_memory_system()
348
  if not _embedder or not _faiss_rules_index or _faiss_rules_index.ntotal == 0:
349
  return []
@@ -362,35 +325,31 @@ def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
362
  return []
363
 
364
  def remove_rule_entry(rule_text_to_delete: str) -> bool:
365
- """Removes a rule from backend and rebuilds FAISS for rules."""
366
  global _rules_items_list, _faiss_rules_index
367
  if not _initialized: initialize_memory_system()
368
  if not _embedder or not _faiss_rules_index: return False
369
 
370
  rule_text_to_delete = rule_text_to_delete.strip()
371
  if rule_text_to_delete not in _rules_items_list:
372
- return False # Not found
373
 
374
  try:
375
  _rules_items_list.remove(rule_text_to_delete)
376
- _rules_items_list.sort() # Maintain sorted order
377
 
378
- # Rebuild FAISS index for rules (simplest way to ensure consistency after removal)
379
  new_faiss_rules_index = faiss.IndexFlatL2(_dimension)
380
  if _rules_items_list:
381
  embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False)
382
  embeddings_np = np.array(embeddings, dtype=np.float32)
383
  if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(_rules_items_list) and embeddings_np.shape[1] == _dimension:
384
  new_faiss_rules_index.add(embeddings_np)
385
- else: # Should not happen if list is consistent
386
  logger.error("Error rebuilding FAISS for rules after removal: Embedding shape error. State might be inconsistent.")
387
- # Attempt to revert _rules_items_list (add back the rule)
388
  _rules_items_list.append(rule_text_to_delete)
389
  _rules_items_list.sort()
390
- return False # Indicate failure
391
  _faiss_rules_index = new_faiss_rules_index
392
 
393
- # Remove from persistent storage
394
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
395
  with _get_sqlite_connection() as conn:
396
  conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,))
@@ -402,25 +361,21 @@ def remove_rule_entry(rule_text_to_delete: str) -> bool:
402
  return True
403
  except Exception as e:
404
  logger.error(f"Error removing rule entry: {e}", exc_info=True)
405
- # Potential partial failure, state might be inconsistent.
406
  return False
407
 
408
- # --- Utility functions to get all data (for UI display, etc.) ---
409
  def get_all_rules_cached() -> list[str]:
410
  if not _initialized: initialize_memory_system()
411
  return list(_rules_items_list)
412
 
413
  def get_all_memories_cached() -> list[dict]:
414
  if not _initialized: initialize_memory_system()
415
- # Convert JSON strings to dicts for easier use by UI
416
  mem_dicts = []
417
  for mem_json_str in _memory_items_list:
418
  try: mem_dicts.append(json.loads(mem_json_str))
419
- except: pass # Ignore parse errors for display
420
  return mem_dicts
421
 
422
  def clear_all_memory_data_backend() -> bool:
423
- """Clears all memories from backend and resets in-memory FAISS/list."""
424
  global _memory_items_list, _faiss_memory_index
425
  if not _initialized: initialize_memory_system()
426
 
@@ -429,11 +384,10 @@ def clear_all_memory_data_backend() -> bool:
429
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
430
  with _get_sqlite_connection() as conn: conn.execute("DELETE FROM memories"); conn.commit()
431
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
432
- # Deleting from HF usually means pushing an empty dataset
433
  Dataset.from_dict({"memory_json": []}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
434
 
435
  _memory_items_list = []
436
- if _faiss_memory_index: _faiss_memory_index.reset() # Clear FAISS index
437
  logger.info("All memories cleared from backend and in-memory stores.")
438
  except Exception as e:
439
  logger.error(f"Error clearing all memory data: {e}")
@@ -441,7 +395,6 @@ def clear_all_memory_data_backend() -> bool:
441
  return success
442
 
443
  def clear_all_rules_data_backend() -> bool:
444
- """Clears all rules from backend and resets in-memory FAISS/list."""
445
  global _rules_items_list, _faiss_rules_index
446
  if not _initialized: initialize_memory_system()
447
 
@@ -460,7 +413,231 @@ def clear_all_rules_data_backend() -> bool:
460
  success = False
461
  return success
462
 
463
- # Optional: Function to save FAISS indices to disk (from ai-learn, if needed for persistence between app runs with RAM backend)
464
  FAISS_MEMORY_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "memory_index.faiss")
465
  FAISS_RULES_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "rules_index.faiss")
466
 
@@ -486,12 +663,11 @@ def load_faiss_indices_from_disk():
486
  global _faiss_memory_index, _faiss_rules_index
487
  if not _initialized or not faiss: return
488
 
489
- if os.path.exists(FAISS_MEMORY_PATH) and _faiss_memory_index: # Check if index object exists
490
  try:
491
  logger.info(f"Loading memory FAISS index from {FAISS_MEMORY_PATH}...")
492
  _faiss_memory_index = faiss.read_index(FAISS_MEMORY_PATH)
493
  logger.info(f"Memory FAISS index loaded ({_faiss_memory_index.ntotal} items).")
494
- # Consistency check: FAISS ntotal vs len(_memory_items_list)
495
  if _faiss_memory_index.ntotal != len(_memory_items_list) and len(_memory_items_list) > 0:
496
  logger.warning(f"Memory FAISS index count ({_faiss_memory_index.ntotal}) differs from loaded texts ({len(_memory_items_list)}). Consider rebuilding FAISS.")
497
  except Exception as e: logger.error(f"Error loading memory FAISS index: {e}. Will use fresh index.")
@@ -503,6 +679,4 @@ def load_faiss_indices_from_disk():
503
  logger.info(f"Rules FAISS index loaded ({_faiss_rules_index.ntotal} items).")
504
  if _faiss_rules_index.ntotal != len(_rules_items_list) and len(_rules_items_list) > 0:
505
  logger.warning(f"Rules FAISS index count ({_faiss_rules_index.ntotal}) differs from loaded texts ({len(_rules_items_list)}). Consider rebuilding FAISS.")
506
- except Exception as e: logger.error(f"Error loading rules FAISS index: {e}. Will use fresh index.")
507
-
508
-
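
For orientation, the retrieval functions above all follow the same embed-and-search pattern: encode the stored texts once into a FAISS IndexFlatL2 index, encode the query at lookup time, and map the returned row indices back to the stored items. The following is a minimal, self-contained sketch of that pattern, not code from this commit; the model name is an assumption (the module only pins the embedding dimension at 384), and the sample items are placeholders.

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# Assumed 384-dimensional model; memory_logic.py itself only fixes _dimension = 384.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
dimension = embedder.get_sentence_embedding_dimension()

# Placeholder items standing in for _rules_items_list / _memory_items_list.
items = [
    "Always answer in a friendly tone.",
    "Prefer concise code examples over long prose.",
]

# Build a flat L2 index over the item embeddings, as initialize_memory_system() does.
index = faiss.IndexFlatL2(dimension)
embeddings = np.array(
    embedder.encode(items, convert_to_tensor=False, show_progress_bar=False),
    dtype=np.float32,
)
index.add(embeddings)

# Lookup: embed the query, search, then map FAISS row ids back to the stored strings.
query_vec = np.array(embedder.encode(["how should code samples look?"]), dtype=np.float32)
distances, indices = index.search(query_vec, k=1)
print([items[i] for i in indices[0] if 0 <= i < len(items)])  # guard against -1 padding ids
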
 
 
1
  import os
2
  import json
3
  import time
 
6
  import re
7
  import threading
8
 
 
9
  try:
10
  from sentence_transformers import SentenceTransformer
11
  import faiss
 
28
 
29
 
30
  logger = logging.getLogger(__name__)
 
31
  for lib_name in ["sentence_transformers", "faiss", "datasets", "huggingface_hub"]:
32
+ if logging.getLogger(lib_name):
33
  logging.getLogger(lib_name).setLevel(logging.WARNING)
34
 
35
 
36
+ STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper()
37
+ SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db")
 
38
  HF_TOKEN = os.getenv("HF_TOKEN")
39
+ HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain")
40
+ HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules")
41
 
 
42
  _embedder = None
43
+ _dimension = 384
44
  _faiss_memory_index = None
45
+ _memory_items_list = []
46
  _faiss_rules_index = None
47
+ _rules_items_list = []
48
 
49
  _initialized = False
50
  _init_lock = threading.Lock()
51
 
 
52
  def _get_sqlite_connection():
53
  if not sqlite3:
54
  raise ImportError("sqlite3 module is required for SQLite backend but not found.")
55
  db_dir = os.path.dirname(SQLITE_DB_PATH)
56
  if db_dir and not os.path.exists(db_dir):
57
  os.makedirs(db_dir, exist_ok=True)
58
+ return sqlite3.connect(SQLITE_DB_PATH, timeout=10)
59
 
60
  def _init_sqlite_tables():
61
  if STORAGE_BACKEND != "SQLITE" or not sqlite3:
 
63
  try:
64
  with _get_sqlite_connection() as conn:
65
  cursor = conn.cursor()
 
66
  cursor.execute("""
67
  CREATE TABLE IF NOT EXISTS memories (
68
  id INTEGER PRIMARY KEY AUTOINCREMENT,
69
  memory_json TEXT NOT NULL,
 
 
70
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
71
  )
72
  """)
 
73
  cursor.execute("""
74
  CREATE TABLE IF NOT EXISTS rules (
75
  id INTEGER PRIMARY KEY AUTOINCREMENT,
76
  rule_text TEXT NOT NULL UNIQUE,
 
77
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
78
  )
79
  """)
 
82
  except Exception as e:
83
  logger.error(f"SQLite table initialization error: {e}", exc_info=True)
84
 
 
85
  def initialize_memory_system():
86
  global _initialized, _embedder, _dimension, _faiss_memory_index, _memory_items_list, _faiss_rules_index, _rules_items_list
87
 
 
93
  logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
94
  init_start_time = time.time()
95
 
 
96
  if not SentenceTransformer or not faiss or not np:
97
  logger.error("Core RAG libraries (SentenceTransformers, FAISS, NumPy) not available. Cannot initialize semantic memory.")
98
+ _initialized = False
99
  return
100
 
101
  if not _embedder:
 
107
  except Exception as e:
108
  logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
109
  _initialized = False
110
+ return
111
 
 
112
  if STORAGE_BACKEND == "SQLITE":
113
  _init_sqlite_tables()
114
 
 
115
  logger.info("Loading memories...")
116
  temp_memories_json = []
117
  if STORAGE_BACKEND == "RAM":
118
+ _memory_items_list = []
119
  elif STORAGE_BACKEND == "SQLITE" and sqlite3:
120
  try:
121
  with _get_sqlite_connection() as conn:
 
124
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset and load_dataset:
125
  try:
126
  logger.info(f"Attempting to load memories from HF Dataset: {HF_MEMORY_DATASET_REPO}")
127
+ dataset = load_dataset(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True)
128
+ if "train" in dataset and "memory_json" in dataset["train"].column_names:
129
  temp_memories_json = [m_json for m_json in dataset["train"]["memory_json"] if isinstance(m_json, str)]
130
  else: logger.warning(f"HF Dataset {HF_MEMORY_DATASET_REPO} for memories not found or 'memory_json' column missing.")
131
  except Exception as e: logger.error(f"Error loading memories from HF Dataset ({HF_MEMORY_DATASET_REPO}): {e}")
 
133
  _memory_items_list = temp_memories_json
134
  logger.info(f"Loaded {len(_memory_items_list)} memory items from {STORAGE_BACKEND}.")
135
 
 
136
  _faiss_memory_index = faiss.IndexFlatL2(_dimension)
137
  if _memory_items_list:
138
  logger.info(f"Building FAISS index for {len(_memory_items_list)} memories...")
 
139
  texts_to_embed_mem = []
140
  for mem_json_str in _memory_items_list:
141
  try:
142
  mem_obj = json.loads(mem_json_str)
 
143
  text = f"User: {mem_obj.get('user_input','')}\nAI: {mem_obj.get('bot_response','')}\nTakeaway: {mem_obj.get('metrics',{}).get('takeaway','N/A')}"
144
  texts_to_embed_mem.append(text)
145
  except json.JSONDecodeError:
 
147
 
148
  if texts_to_embed_mem:
149
  try:
150
+ embeddings = _embedder.encode(texts_to_embed_mem, convert_to_tensor=False, show_progress_bar=False)
151
  embeddings_np = np.array(embeddings, dtype=np.float32)
152
  if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(texts_to_embed_mem) and embeddings_np.shape[1] == _dimension:
153
  _faiss_memory_index.add(embeddings_np)
 
155
  except Exception as e_faiss_mem: logger.error(f"Error building FAISS memory index: {e_faiss_mem}")
156
  logger.info(f"FAISS memory index built. Total items: {_faiss_memory_index.ntotal if _faiss_memory_index else 'N/A'}")
157
 
 
 
158
  logger.info("Loading rules...")
159
  temp_rules_text = []
160
  if STORAGE_BACKEND == "RAM":
 
173
  else: logger.warning(f"HF Dataset {HF_RULES_DATASET_REPO} for rules not found or 'rule_text' column missing.")
174
  except Exception as e: logger.error(f"Error loading rules from HF Dataset ({HF_RULES_DATASET_REPO}): {e}")
175
 
176
+ _rules_items_list = sorted(list(set(temp_rules_text)))
177
  logger.info(f"Loaded {len(_rules_items_list)} rule items from {STORAGE_BACKEND}.")
178
 
 
179
  _faiss_rules_index = faiss.IndexFlatL2(_dimension)
180
  if _rules_items_list:
181
  logger.info(f"Building FAISS index for {len(_rules_items_list)} rules...")
182
+ if _rules_items_list:
183
  try:
184
  embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False, show_progress_bar=False)
185
  embeddings_np = np.array(embeddings, dtype=np.float32)
 
193
  logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")
194
 
195
 
 
196
  def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
 
197
  global _memory_items_list, _faiss_memory_index
198
  if not _initialized: initialize_memory_system()
199
  if not _embedder or not _faiss_memory_index:
 
217
  logger.error(f"Memory embedding shape error: {embedding_np.shape}. Expected (1, {_dimension})")
218
  return False, "Embedding shape error."
219
 
 
220
  _faiss_memory_index.add(embedding_np)
221
 
 
222
  _memory_items_list.append(memory_json_str)
223
 
 
224
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
225
  with _get_sqlite_connection() as conn:
226
  conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json_str,))
227
  conn.commit()
228
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
 
229
  logger.info(f"Pushing {len(_memory_items_list)} memories to HF Hub: {HF_MEMORY_DATASET_REPO}")
230
+ Dataset.from_dict({"memory_json": list(_memory_items_list)}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
231
 
232
  logger.info(f"Added memory. RAM: {len(_memory_items_list)}, FAISS: {_faiss_memory_index.ntotal}")
233
  return True, "Memory added successfully."
234
  except Exception as e:
235
  logger.error(f"Error adding memory entry: {e}", exc_info=True)
 
236
  return False, f"Error adding memory: {e}"
237
 
238
  def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
 
239
  if not _initialized: initialize_memory_system()
240
  if not _embedder or not _faiss_memory_index or _faiss_memory_index.ntotal == 0:
241
  logger.debug("Cannot retrieve memories: Embedder, FAISS index not ready, or index is empty.")
 
268
  return []
269
 
270
 
 
271
  def add_rule_entry(rule_text: str) -> tuple[bool, str]:
 
272
  global _rules_items_list, _faiss_rules_index
273
  if not _initialized: initialize_memory_system()
274
  if not _embedder or not _faiss_rules_index:
 
304
  return True, "Rule added successfully."
305
  except Exception as e:
306
  logger.error(f"Error adding rule entry: {e}", exc_info=True)
 
 
 
 
 
307
  return False, f"Error adding rule: {e}"
308
 
309
  def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
 
310
  if not _initialized: initialize_memory_system()
311
  if not _embedder or not _faiss_rules_index or _faiss_rules_index.ntotal == 0:
312
  return []
 
325
  return []
326
 
327
  def remove_rule_entry(rule_text_to_delete: str) -> bool:
 
328
  global _rules_items_list, _faiss_rules_index
329
  if not _initialized: initialize_memory_system()
330
  if not _embedder or not _faiss_rules_index: return False
331
 
332
  rule_text_to_delete = rule_text_to_delete.strip()
333
  if rule_text_to_delete not in _rules_items_list:
334
+ return False
335
 
336
  try:
337
  _rules_items_list.remove(rule_text_to_delete)
338
+ _rules_items_list.sort()
339
 
 
340
  new_faiss_rules_index = faiss.IndexFlatL2(_dimension)
341
  if _rules_items_list:
342
  embeddings = _embedder.encode(_rules_items_list, convert_to_tensor=False)
343
  embeddings_np = np.array(embeddings, dtype=np.float32)
344
  if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(_rules_items_list) and embeddings_np.shape[1] == _dimension:
345
  new_faiss_rules_index.add(embeddings_np)
346
+ else:
347
  logger.error("Error rebuilding FAISS for rules after removal: Embedding shape error. State might be inconsistent.")
 
348
  _rules_items_list.append(rule_text_to_delete)
349
  _rules_items_list.sort()
350
+ return False
351
  _faiss_rules_index = new_faiss_rules_index
352
 
 
353
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
354
  with _get_sqlite_connection() as conn:
355
  conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,))
 
361
  return True
362
  except Exception as e:
363
  logger.error(f"Error removing rule entry: {e}", exc_info=True)
 
364
  return False
365
 
 
366
  def get_all_rules_cached() -> list[str]:
367
  if not _initialized: initialize_memory_system()
368
  return list(_rules_items_list)
369
 
370
  def get_all_memories_cached() -> list[dict]:
371
  if not _initialized: initialize_memory_system()
 
372
  mem_dicts = []
373
  for mem_json_str in _memory_items_list:
374
  try: mem_dicts.append(json.loads(mem_json_str))
375
+ except: pass
376
  return mem_dicts
377
 
378
  def clear_all_memory_data_backend() -> bool:
 
379
  global _memory_items_list, _faiss_memory_index
380
  if not _initialized: initialize_memory_system()
381
 
 
384
  if STORAGE_BACKEND == "SQLITE" and sqlite3:
385
  with _get_sqlite_connection() as conn: conn.execute("DELETE FROM memories"); conn.commit()
386
  elif STORAGE_BACKEND == "HF_DATASET" and HF_TOKEN and Dataset:
 
387
  Dataset.from_dict({"memory_json": []}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
388
 
389
  _memory_items_list = []
390
+ if _faiss_memory_index: _faiss_memory_index.reset()
391
  logger.info("All memories cleared from backend and in-memory stores.")
392
  except Exception as e:
393
  logger.error(f"Error clearing all memory data: {e}")
 
395
  return success
396
 
397
  def clear_all_rules_data_backend() -> bool:
 
398
  global _rules_items_list, _faiss_rules_index
399
  if not _initialized: initialize_memory_system()
400
 
 
413
  success = False
414
  return success
415
 
416
+ def load_rules_from_file(filepath: str | None):
417
+ if not filepath:
418
+ logger.info("LOAD_RULES_FILE environment variable not set. Skipping rules loading from file.")
419
+ return 0, 0, 0
420
+
421
+ if not os.path.exists(filepath):
422
+ logger.warning(f"LOAD_RULES: Specified rules file not found: {filepath}. Skipping loading.")
423
+ return 0, 0, 0
424
+
425
+ added_count, skipped_count, error_count = 0, 0, 0
426
+ potential_rules = []
427
+
428
+ try:
429
+ with open(filepath, 'r', encoding='utf-8') as f:
430
+ content = f.read()
431
+ except Exception as e:
432
+ logger.error(f"LOAD_RULES: Error reading file {filepath}: {e}", exc_info=False)
433
+ return 0, 0, 1
434
+
435
+ if not content.strip():
436
+ logger.info(f"LOAD_RULES: File {filepath} is empty. Skipping loading.")
437
+ return 0, 0, 0
438
+
439
+ file_name_lower = filepath.lower()
440
+
441
+ if file_name_lower.endswith(".txt"):
442
+ potential_rules = content.split("\n\n---\n\n")
443
+ if len(potential_rules) == 1 and "\n" in content:
444
+ potential_rules = [r.strip() for r in content.splitlines() if r.strip()]
445
+ elif file_name_lower.endswith(".jsonl"):
446
+ for line_num, line in enumerate(content.splitlines()):
447
+ line = line.strip()
448
+ if line:
449
+ try:
450
+ rule_text_in_json_string = json.loads(line)
451
+ if isinstance(rule_text_in_json_string, str):
452
+ potential_rules.append(rule_text_in_json_string)
453
+ else:
454
+ logger.warning(f"LOAD_RULES (JSONL): Line {line_num+1} in {filepath} did not contain a string value. Got: {type(rule_text_in_json_string)}")
455
+ error_count +=1
456
+ except json.JSONDecodeError:
457
+ logger.warning(f"LOAD_RULES (JSONL): Line {line_num+1} in {filepath} failed to parse as JSON: {line[:100]}")
458
+ error_count +=1
459
+ else:
460
+ logger.error(f"LOAD_RULES: Unsupported file type for rules: {filepath}. Must be .txt or .jsonl")
461
+ return 0, 0, 1
462
+
463
+ valid_potential_rules = [r.strip() for r in potential_rules if r.strip()]
464
+ total_to_process = len(valid_potential_rules)
465
+
466
+ if total_to_process == 0 and error_count == 0:
467
+ logger.info(f"LOAD_RULES: No valid rule segments found in {filepath} to process.")
468
+ return 0, 0, 0
469
+ elif total_to_process == 0 and error_count > 0:
470
+ logger.warning(f"LOAD_RULES: No valid rule segments found to process. Encountered {error_count} parsing/format errors in {filepath}.")
471
+ return 0, 0, error_count
472
+
473
+ logger.info(f"LOAD_RULES: Attempting to add {total_to_process} potential rules from {filepath}...")
474
+ for idx, rule_text in enumerate(valid_potential_rules):
475
+ success, status_msg = add_rule_entry(rule_text)
476
+ if success:
477
+ added_count += 1
478
+ elif status_msg == "duplicate":
479
+ skipped_count += 1
480
+ else:
481
+ logger.warning(f"LOAD_RULES: Failed to add rule from {filepath} (segment {idx+1}): '{rule_text[:50]}...'. Status: {status_msg}")
482
+ error_count += 1
483
+
484
+ logger.info(f"LOAD_RULES: Finished processing {filepath}. Added: {added_count}, Skipped (duplicates): {skipped_count}, Errors: {error_count}.")
485
+ return added_count, skipped_count, error_count
486
+
487
+ def load_memories_from_file(filepath: str | None):
488
+ if not filepath:
489
+ logger.info("LOAD_MEMORIES_FILE environment variable not set. Skipping memories loading from file.")
490
+ return 0, 0, 0
491
+
492
+ if not os.path.exists(filepath):
493
+ logger.warning(f"LOAD_MEMORIES: Specified memories file not found: {filepath}. Skipping loading.")
494
+ return 0, 0, 0
495
+
496
+ added_count, format_error_count, save_error_count = 0, 0, 0
497
+ memory_objects_to_process = []
498
+
499
+ try:
500
+ with open(filepath, 'r', encoding='utf-8') as f:
501
+ content = f.read()
502
+ except Exception as e:
503
+ logger.error(f"LOAD_MEMORIES: Error reading file {filepath}: {e}", exc_info=False)
504
+ return 0, 1, 0
505
+
506
+ if not content.strip():
507
+ logger.info(f"LOAD_MEMORIES: File {filepath} is empty. Skipping loading.")
508
+ return 0, 0, 0
509
+
510
+ file_ext = os.path.splitext(filepath.lower())[1]
511
+
512
+ if file_ext == ".json":
513
+ try:
514
+ parsed_json = json.loads(content)
515
+ if isinstance(parsed_json, list):
516
+ memory_objects_to_process = parsed_json
517
+ elif isinstance(parsed_json, dict):
518
+ memory_objects_to_process = [parsed_json]
519
+ else:
520
+ logger.warning(f"LOAD_MEMORIES (.json): File content is not a JSON list or object in {filepath}. Type: {type(parsed_json)}")
521
+ format_error_count = 1
522
+ except json.JSONDecodeError as e:
523
+ logger.warning(f"LOAD_MEMORIES (.json): Invalid JSON file {filepath}. Error: {e}")
524
+ format_error_count = 1
525
+ elif file_ext == ".jsonl":
526
+ for line_num, line in enumerate(content.splitlines()):
527
+ line = line.strip()
528
+ if line:
529
+ try:
530
+ memory_objects_to_process.append(json.loads(line))
531
+ except json.JSONDecodeError:
532
+ logger.warning(f"LOAD_MEMORIES (.jsonl): Line {line_num+1} in {filepath} parse error: {line[:100]}")
533
+ format_error_count += 1
534
+ else:
535
+ logger.error(f"LOAD_MEMORIES: Unsupported file type for memories: {filepath}. Must be .json or .jsonl")
536
+ return 0, 1, 0
537
+
538
+ total_to_process = len(memory_objects_to_process)
539
+
540
+ if total_to_process == 0 and format_error_count > 0 :
541
+ logger.warning(f"LOAD_MEMORIES: File parsing failed for {filepath}. Found {format_error_count} format errors and no processable objects.")
542
+ return 0, format_error_count, 0
543
+ elif total_to_process == 0:
544
+ logger.info(f"LOAD_MEMORIES: No memory objects found in {filepath} after parsing.")
545
+ return 0, 0, 0
546
+
547
+ logger.info(f"LOAD_MEMORIES: Attempting to add {total_to_process} memory objects from {filepath}...")
548
+ for idx, mem_data in enumerate(memory_objects_to_process):
549
+ if isinstance(mem_data, dict) and all(k in mem_data for k in ["user_input", "bot_response", "metrics"]):
550
+ success, _ = add_memory_entry(mem_data["user_input"], mem_data["metrics"], mem_data["bot_response"])
551
+ if success:
552
+ added_count += 1
553
+ else:
554
+ logger.warning(f"LOAD_MEMORIES: Failed to save memory object from {filepath} (segment {idx+1}). Data: {str(mem_data)[:100]}")
555
+ save_error_count += 1
556
+ else:
557
+ logger.warning(f"LOAD_MEMORIES: Skipped invalid memory object structure in {filepath} (segment {idx+1}): {str(mem_data)[:100]}")
558
+ format_error_count += 1
559
+
560
+ logger.info(f"LOAD_MEMORIES: Finished processing {filepath}. Added: {added_count}, Format/Structure Errors: {format_error_count}, Save Errors: {save_error_count}.")
561
+ return added_count, format_error_count, save_error_count
562
+
563
+
564
+ def process_rules_from_text_blob(rules_text: str, progress_callback=None) -> dict:
565
+ if not rules_text.strip():
566
+ return {"added": 0, "skipped": 0, "errors": 0, "total": 0}
567
+
568
+ potential_rules = rules_text.split("\n\n---\n\n")
569
+ if len(potential_rules) == 1 and "\n" in rules_text:
570
+ potential_rules = [r.strip() for r in rules_text.splitlines() if r.strip()]
571
+
572
+ unique_rules = sorted(list(set(filter(None, [r.strip() for r in potential_rules]))))
573
+ total_unique = len(unique_rules)
574
+ if total_unique == 0:
575
+ return {"added": 0, "skipped": 0, "errors": 0, "total": 0}
576
+
577
+ stats = {"added": 0, "skipped": 0, "errors": 0, "total": total_unique}
578
+ for idx, rule_text in enumerate(unique_rules):
579
+ success, status_msg = add_rule_entry(rule_text)
580
+ if success:
581
+ stats["added"] += 1
582
+ elif status_msg == "duplicate":
583
+ stats["skipped"] += 1
584
+ else:
585
+ stats["errors"] += 1
586
+
587
+ if progress_callback is not None:
588
+ progress_callback((idx + 1) / total_unique, desc=f"Processed {idx+1}/{total_unique} rules...")
589
+
590
+ return stats
591
+
592
+
593
+ def import_kb_from_kv_dict(kv_dict: dict, progress_callback=None) -> dict:
594
+ rules_to_add, memories_to_add = [], []
595
+ for key, value in kv_dict.items():
596
+ if key.startswith("rule_"):
597
+ try:
598
+ rules_to_add.append(json.loads(value))
599
+ except:
600
+ logger.warning(f"KB Dict Import: Bad rule format for key {key}")
601
+ elif key.startswith("memory_"):
602
+ try:
603
+ mem_dict = json.loads(value)
604
+ if isinstance(mem_dict, dict) and all(k in mem_dict for k in ['user_input', 'bot_response', 'metrics']):
605
+ memories_to_add.append(mem_dict)
606
+ except:
607
+ logger.warning(f"KB Dict Import: Bad memory format for key {key}")
608
+
609
+ stats = {"rules_added": 0, "rules_skipped": 0, "rules_errors": 0, "mems_added": 0, "mems_errors": 0}
610
+ total_items = len(rules_to_add) + len(memories_to_add)
611
+ processed_items = 0
612
+
613
+ if progress_callback is not None:
614
+ progress_callback(0, desc=f"Importing {total_items} items...")
615
+
616
+ for rule in rules_to_add:
617
+ s, m = add_rule_entry(rule)
618
+ if s:
619
+ stats["rules_added"] += 1
620
+ elif m == "duplicate":
621
+ stats["rules_skipped"] += 1
622
+ else:
623
+ stats["rules_errors"] += 1
624
+ processed_items += 1
625
+ if progress_callback is not None and total_items > 0:
626
+ progress_callback(processed_items / total_items, desc=f"Processing item {processed_items}/{total_items}...")
627
+
628
+ for mem in memories_to_add:
629
+ s, _ = add_memory_entry(mem['user_input'], mem['metrics'], mem['bot_response'])
630
+ if s:
631
+ stats["mems_added"] += 1
632
+ else:
633
+ stats["mems_errors"] += 1
634
+ processed_items += 1
635
+ if progress_callback is not None and total_items > 0:
636
+ progress_callback(processed_items / total_items, desc=f"Processing item {processed_items}/{total_items}...")
637
+
638
+ return stats
639
+
640
+
641
  FAISS_MEMORY_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "memory_index.faiss")
642
  FAISS_RULES_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "rules_index.faiss")
643
 
 
663
  global _faiss_memory_index, _faiss_rules_index
664
  if not _initialized or not faiss: return
665
 
666
+ if os.path.exists(FAISS_MEMORY_PATH) and _faiss_memory_index:
667
  try:
668
  logger.info(f"Loading memory FAISS index from {FAISS_MEMORY_PATH}...")
669
  _faiss_memory_index = faiss.read_index(FAISS_MEMORY_PATH)
670
  logger.info(f"Memory FAISS index loaded ({_faiss_memory_index.ntotal} items).")
 
671
  if _faiss_memory_index.ntotal != len(_memory_items_list) and len(_memory_items_list) > 0:
672
  logger.warning(f"Memory FAISS index count ({_faiss_memory_index.ntotal}) differs from loaded texts ({len(_memory_items_list)}). Consider rebuilding FAISS.")
673
  except Exception as e: logger.error(f"Error loading memory FAISS index: {e}. Will use fresh index.")
 
679
  logger.info(f"Rules FAISS index loaded ({_faiss_rules_index.ntotal} items).")
680
  if _faiss_rules_index.ntotal != len(_rules_items_list) and len(_rules_items_list) > 0:
681
  logger.warning(f"Rules FAISS index count ({_faiss_rules_index.ntotal}) differs from loaded texts ({len(_rules_items_list)}). Consider rebuilding FAISS.")
682
+ except Exception as e: logger.error(f"Error loading rules FAISS index: {e}. Will use fresh index.")
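
The bulk-loading helpers added by this commit (load_rules_from_file, load_memories_from_file, process_rules_from_text_blob, import_kb_from_kv_dict) are plain module-level functions, so a small driver script can exercise them directly. The sketch below is hypothetical: the file paths are placeholders, the RAM backend is chosen only to avoid touching SQLite or the HF Hub during a dry run, and it assumes the RAG dependencies (sentence-transformers, faiss, numpy) are installed. Note that STORAGE_BACKEND must be set before the module is imported, because the module reads it at import time.

import os

# Must be set before importing memory_logic: the backend is read at import time.
os.environ.setdefault("STORAGE_BACKEND", "RAM")

import memory_logic

memory_logic.initialize_memory_system()

# Rules: a .txt file split on "\n\n---\n\n" (or one rule per line), or a .jsonl of strings.
added, skipped, errors = memory_logic.load_rules_from_file("app_data/seed_rules.txt")  # placeholder path
print(f"rules -> added: {added}, duplicates skipped: {skipped}, errors: {errors}")

# Memories: a .json list (or single object) or a .jsonl of objects carrying
# user_input / bot_response / metrics keys.
added, format_errors, save_errors = memory_logic.load_memories_from_file("app_data/seed_memories.jsonl")  # placeholder path
print(f"memories -> added: {added}, format errors: {format_errors}, save errors: {save_errors}")

# Rules can also be ingested from a pasted text blob.
stats = memory_logic.process_rules_from_text_blob("Be concise.\nCite sources when you can.")
print(stats)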