broadfield-dev committed
Commit 6c9ab1e · verified · 1 Parent(s): 60122df

Create memory_logic.py

Files changed (1)
  1. memory_logic.py +381 -0
memory_logic.py ADDED
@@ -0,0 +1,381 @@
# memory_logic.py
import os
import json
import time
from datetime import datetime
import logging
import re
import threading

# Conditionally import heavy dependencies
try:
    from sentence_transformers import SentenceTransformer
    import faiss
    import numpy as np
except ImportError:
    SentenceTransformer, faiss, np = None, None, None
    logging.warning("SentenceTransformers, FAISS, or NumPy not installed. Semantic search will be unavailable.")

try:
    import sqlite3
except ImportError:
    sqlite3 = None
    logging.warning("sqlite3 module not available. SQLite backend will be unavailable.")

try:
    from datasets import load_dataset, Dataset
except ImportError:
    load_dataset, Dataset = None, None
    logging.warning("datasets library not installed. Hugging Face Dataset backend will be unavailable.")


logger = logging.getLogger(__name__)
# Suppress verbose logs from dependencies
# (logging.getLogger always returns a Logger, so no existence check is needed)
for lib_name in ["sentence_transformers", "faiss", "datasets", "huggingface_hub"]:
    logging.getLogger(lib_name).setLevel(logging.WARNING)

# --- Configuration (read directly from environment variables) ---
STORAGE_BACKEND = os.getenv("STORAGE_BACKEND", "HF_DATASET").upper()  # HF_DATASET, RAM, or SQLITE
SQLITE_DB_PATH = os.getenv("SQLITE_DB_PATH", "app_data/ai_memory.db")
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MEMORY_DATASET_REPO = os.getenv("HF_MEMORY_DATASET_REPO", "broadfield-dev/ai-brain")  # Example
HF_RULES_DATASET_REPO = os.getenv("HF_RULES_DATASET_REPO", "broadfield-dev/ai-rules")  # Example

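# Example environment configuration (illustrative only; the repo IDs and token
# below are placeholders, not real resources):
#   export STORAGE_BACKEND=HF_DATASET
#   export HF_TOKEN=hf_your_write_token
#   export HF_MEMORY_DATASET_REPO=your-org/ai-brain
#   export HF_RULES_DATASET_REPO=your-org/ai-rules
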
# --- Globals for RAG within this module ---
_embedder = None
_dimension = 384  # Default; overwritten with the embedder's actual dimension
_faiss_memory_index = None
_memory_items_list = []  # JSON strings of memory objects (RAM cache, loaded from DB/HF)
_faiss_rules_index = None
_rules_items_list = []  # Rule text strings

_initialized = False
_init_lock = threading.Lock()

# --- Helper: SQLite Connection ---
def _get_sqlite_connection():
    if not sqlite3:
        raise ImportError("sqlite3 module is required for SQLite backend but not found.")
    db_dir = os.path.dirname(SQLITE_DB_PATH)
    if db_dir and not os.path.exists(db_dir):
        os.makedirs(db_dir, exist_ok=True)
    return sqlite3.connect(SQLITE_DB_PATH, timeout=10)

def _init_sqlite_tables():
    if STORAGE_BACKEND != "SQLITE" or not sqlite3:
        return
    try:
        with _get_sqlite_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    memory_json TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS rules (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    rule_text TEXT NOT NULL UNIQUE,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
        logger.info("SQLite tables for memories and rules checked/created.")
    except Exception as e:
        logger.error(f"SQLite table initialization error: {e}", exc_info=True)

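# Note: a sqlite3 connection used as a context manager commits (or rolls back)
# the enclosed transaction on exit, but does not close the connection itself;
# the short-lived connections above are reclaimed when they go out of scope.
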
def _build_faiss_index(items_list, text_extraction_fn):
    """Builds a FAISS index from a list of items."""
    if not _embedder:
        logger.error("Cannot build FAISS index: Embedder not available.")
        return None

    index = faiss.IndexFlatL2(_dimension)
    if not items_list:
        return index

    logger.info(f"Building FAISS index for {len(items_list)} items...")
    texts_to_embed = [text_extraction_fn(item) for item in items_list]

    try:
        embeddings = _embedder.encode(texts_to_embed, convert_to_tensor=False, show_progress_bar=False)
        embeddings_np = np.array(embeddings, dtype=np.float32)
        if embeddings_np.ndim == 2 and embeddings_np.shape[0] == len(items_list):
            index.add(embeddings_np)
            logger.info(f"FAISS index built successfully with {index.ntotal} items.")
        else:
            logger.error(f"FAISS build failed: Embeddings shape error. Expected ({len(items_list)}, {_dimension}), got {getattr(embeddings_np, 'shape', 'N/A')}")
            return faiss.IndexFlatL2(_dimension)  # Return empty index on failure
    except Exception as e:
        logger.error(f"Error building FAISS index: {e}", exc_info=True)
        return faiss.IndexFlatL2(_dimension)  # Return empty index on failure

    return index

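# Note: IndexFlatL2 performs exact (brute-force) L2 search, a good fit for the
# small collections this module targets; for much larger corpora an approximate
# index such as faiss.IndexIVFFlat could be swapped in, at the cost of a
# training step.
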
# --- Initialization ---
def initialize_memory_system():
    global _initialized, _embedder, _dimension, _faiss_memory_index, _memory_items_list, _faiss_rules_index, _rules_items_list

    with _init_lock:
        if _initialized:
            return

        logger.info(f"Initializing memory system with backend: {STORAGE_BACKEND}")
        init_start_time = time.time()

        if not SentenceTransformer or not faiss or not np:
            logger.error("Core RAG libraries not available. Cannot initialize semantic memory.")
            return

        if not _embedder:
            try:
                logger.info("Loading SentenceTransformer model (all-MiniLM-L6-v2)...")
                _embedder = SentenceTransformer('all-MiniLM-L6-v2', cache_folder="./sentence_transformer_cache")
                _dimension = _embedder.get_sentence_embedding_dimension() or 384
            except Exception as e:
                logger.critical(f"FATAL: Error loading SentenceTransformer: {e}", exc_info=True)
                return

        if STORAGE_BACKEND == "SQLITE":
            _init_sqlite_tables()

        # Load memories from persistent storage
        temp_memories_json = []
        if STORAGE_BACKEND == "SQLITE":
            try:
                temp_memories_json = [row[0] for row in _get_sqlite_connection().execute("SELECT memory_json FROM memories")]
            except Exception as e:
                logger.error(f"Error loading memories from SQLite: {e}")
        elif STORAGE_BACKEND == "HF_DATASET":
            try:
                logger.info(f"Loading memories from HF Dataset: {HF_MEMORY_DATASET_REPO}")
                dataset = load_dataset(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True)
                if "train" in dataset and "memory_json" in dataset["train"].column_names:
                    temp_memories_json = [m for m in dataset["train"]["memory_json"] if isinstance(m, str) and m.strip()]
                    logger.info(f"Loaded {len(temp_memories_json)} valid memories from HF Dataset.")
                else:
                    logger.warning(f"HF Dataset {HF_MEMORY_DATASET_REPO} has no 'train' split or 'memory_json' column.")
            except Exception as e:
                logger.error(f"Error loading memories from HF Dataset: {e}", exc_info=True)

        _memory_items_list = temp_memories_json

        # Build memory FAISS index
        _faiss_memory_index = _build_faiss_index(
            _memory_items_list,
            lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}\nTakeaway: {json.loads(m).get('metrics', {}).get('takeaway', 'N/A')}"
        )

        # Load rules from persistent storage
        temp_rules_text = []
        if STORAGE_BACKEND == "SQLITE":
            try:
                temp_rules_text = [row[0] for row in _get_sqlite_connection().execute("SELECT rule_text FROM rules")]
            except Exception as e:
                logger.error(f"Error loading rules from SQLite: {e}")
        elif STORAGE_BACKEND == "HF_DATASET":
            try:
                logger.info(f"Loading rules from HF Dataset: {HF_RULES_DATASET_REPO}")
                dataset = load_dataset(HF_RULES_DATASET_REPO, token=HF_TOKEN, trust_remote_code=True)
                if "train" in dataset and "rule_text" in dataset["train"].column_names:
                    temp_rules_text = [r for r in dataset["train"]["rule_text"] if isinstance(r, str) and r.strip()]
                    logger.info(f"Loaded {len(temp_rules_text)} valid rules from HF Dataset.")
                else:
                    logger.warning(f"HF Dataset {HF_RULES_DATASET_REPO} has no 'train' split or 'rule_text' column.")
            except Exception as e:
                logger.error(f"Error loading rules from HF Dataset: {e}", exc_info=True)

        _rules_items_list = sorted(list(set(temp_rules_text)))

        # Build rules FAISS index
        _faiss_rules_index = _build_faiss_index(_rules_items_list, lambda r: r)

        _initialized = True
        logger.info(f"Memory system initialization complete in {time.time() - init_start_time:.2f}s")

def _verify_and_rebuild_if_needed(index, items_list, text_extraction_fn):
    """Self-healing check: ensures the FAISS index is synced with the item list."""
    if not index or index.ntotal != len(items_list):
        logger.warning(
            f"FAISS index mismatch detected (Index: {index.ntotal if index else 'None'}, List: {len(items_list)}). "
            "Rebuilding index from in-memory cache."
        )
        return _build_faiss_index(items_list, text_extraction_fn)
    return index

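# A stored memory record is a JSON string with this shape (illustrative values):
#   {"user_input": "...", "metrics": {"takeaway": "..."}, "bot_response": "...",
#    "timestamp": "2024-01-01T12:00:00"}
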
# --- Memory Operations (Semantic) ---
def add_memory_entry(user_input: str, metrics: dict, bot_response: str) -> tuple[bool, str]:
    global _memory_items_list, _faiss_memory_index
    if not _initialized:
        initialize_memory_system()
    if not _embedder or not _faiss_memory_index:
        return False, "Memory system not ready for adding entries."

    memory_obj = {"user_input": user_input, "metrics": metrics, "bot_response": bot_response, "timestamp": datetime.utcnow().isoformat()}
    memory_json_str = json.dumps(memory_obj)
    text_to_embed = f"User: {user_input}\nAI: {bot_response}\nTakeaway: {metrics.get('takeaway', 'N/A')}"

    try:
        embedding = _embedder.encode([text_to_embed], convert_to_tensor=False)
        embedding_np = np.array(embedding, dtype=np.float32)

        _faiss_memory_index.add(embedding_np)
        _memory_items_list.append(memory_json_str)

        if STORAGE_BACKEND == "SQLITE":
            with _get_sqlite_connection() as conn:
                conn.execute("INSERT INTO memories (memory_json) VALUES (?)", (memory_json_str,))
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET":
            Dataset.from_dict({"memory_json": list(_memory_items_list)}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)

        logger.info(f"Added memory. Cache size: {len(_memory_items_list)}, FAISS size: {_faiss_memory_index.ntotal}")
        return True, "Memory added successfully."
    except Exception as e:
        logger.error(f"Error adding memory entry: {e}", exc_info=True)
        return False, f"Error adding memory: {e}"

def retrieve_memories_semantic(query: str, k: int = 3) -> list[dict]:
    global _faiss_memory_index
    if not _initialized:
        initialize_memory_system()

    # Self-healing: verify the index is synced with the cache; rebuild if not.
    _faiss_memory_index = _verify_and_rebuild_if_needed(
        _faiss_memory_index, _memory_items_list,
        lambda m: f"User: {json.loads(m).get('user_input', '')}\nAI: {json.loads(m).get('bot_response', '')}\nTakeaway: {json.loads(m).get('metrics', {}).get('takeaway', 'N/A')}"
    )

    if not _faiss_memory_index or _faiss_memory_index.ntotal == 0:
        logger.debug("Cannot retrieve memories: index is empty.")
        return []

    try:
        query_embedding = _embedder.encode([query], convert_to_tensor=False)
        query_embedding_np = np.array(query_embedding, dtype=np.float32)
        distances, indices = _faiss_memory_index.search(query_embedding_np, min(k, _faiss_memory_index.ntotal))

        results = [json.loads(_memory_items_list[i]) for i in indices[0] if 0 <= i < len(_memory_items_list)]
        logger.info(f"Retrieved {len(results)} memories for query: '{query[:50]}...'")
        return results
    except Exception as e:
        logger.error(f"Error retrieving memories semantically: {e}", exc_info=True)
        return []

# --- Rule (Insight) Operations (Semantic) ---
def add_rule_entry(rule_text: str) -> tuple[bool, str]:
    global _rules_items_list, _faiss_rules_index
    if not _initialized:
        initialize_memory_system()

    rule_text = rule_text.strip()
    if not rule_text or rule_text == "duplicate" or rule_text in _rules_items_list:
        return False, "duplicate or invalid"
    if not re.match(r"\[(CORE_RULE|RESPONSE_PRINCIPLE|BEHAVIORAL_ADJUSTMENT|GENERAL_LEARNING)\|([\d\.]+?)\]", rule_text, re.I):
        return False, "Invalid rule format."

    try:
        _rules_items_list.append(rule_text)
        _rules_items_list.sort()
        # Rebuild the index so FAISS positions stay aligned with the re-sorted list;
        # appending the new embedding alone would leave index order and list order
        # out of sync, and retrieval maps FAISS row i to _rules_items_list[i].
        _faiss_rules_index = _build_faiss_index(_rules_items_list, lambda r: r)

        if STORAGE_BACKEND == "SQLITE":
            with _get_sqlite_connection() as conn:
                conn.execute("INSERT OR IGNORE INTO rules (rule_text) VALUES (?)", (rule_text,))
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET":
            Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
        return True, "Rule added successfully."
    except Exception as e:
        logger.error(f"Error adding rule entry: {e}", exc_info=True)
        return False, f"Error adding rule: {e}"

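# Example of a rule string accepted by add_rule_entry (illustrative content):
#   "[CORE_RULE|1.0] Always state when an answer comes from stored memory."
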
def retrieve_rules_semantic(query: str, k: int = 5) -> list[str]:
    global _faiss_rules_index
    if not _initialized:
        initialize_memory_system()

    _faiss_rules_index = _verify_and_rebuild_if_needed(_faiss_rules_index, _rules_items_list, lambda r: r)

    if not _faiss_rules_index or _faiss_rules_index.ntotal == 0:
        return []
    try:
        query_embedding = _embedder.encode([query], convert_to_tensor=False)
        query_embedding_np = np.array(query_embedding, dtype=np.float32)
        distances, indices = _faiss_rules_index.search(query_embedding_np, min(k, _faiss_rules_index.ntotal))
        return [_rules_items_list[i] for i in indices[0] if 0 <= i < len(_rules_items_list)]
    except Exception as e:
        logger.error(f"Error retrieving rules semantically: {e}", exc_info=True)
        return []

def remove_rule_entry(rule_text_to_delete: str) -> bool:
    global _rules_items_list, _faiss_rules_index
    if not _initialized:
        initialize_memory_system()
    rule_text_to_delete = rule_text_to_delete.strip()
    if rule_text_to_delete not in _rules_items_list:
        return False
    try:
        _rules_items_list.remove(rule_text_to_delete)
        _faiss_rules_index = _build_faiss_index(_rules_items_list, lambda r: r)

        if STORAGE_BACKEND == "SQLITE":
            with _get_sqlite_connection() as conn:
                conn.execute("DELETE FROM rules WHERE rule_text = ?", (rule_text_to_delete,))
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET":
            Dataset.from_dict({"rule_text": list(_rules_items_list)}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
        return True
    except Exception as e:
        logger.error(f"Error removing rule entry: {e}", exc_info=True)
        return False

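# Note: removal rebuilds the rules index from scratch rather than deleting a
# single vector; this keeps FAISS row positions aligned with the list and is
# cheap at the scale this module targets.
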
# --- Utility functions to get all data (for UI display, etc.) ---
def get_all_rules_cached() -> list[str]:
    if not _initialized:
        initialize_memory_system()
    return list(_rules_items_list)

def get_all_memories_cached() -> list[dict]:
    if not _initialized:
        initialize_memory_system()
    return [json.loads(m) for m in _memory_items_list if m]

def clear_all_memory_data_backend() -> bool:
    global _memory_items_list, _faiss_memory_index
    if not _initialized:
        initialize_memory_system()
    _memory_items_list = []
    if _faiss_memory_index:
        _faiss_memory_index.reset()
    try:
        if STORAGE_BACKEND == "SQLITE":
            with _get_sqlite_connection() as conn:
                conn.execute("DELETE FROM memories")
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET":
            Dataset.from_dict({"memory_json": []}).push_to_hub(HF_MEMORY_DATASET_REPO, token=HF_TOKEN, private=True)
        logger.info("All memories cleared.")
        return True
    except Exception as e:
        logger.error(f"Error clearing all memory data: {e}")
        return False

def clear_all_rules_data_backend() -> bool:
    global _rules_items_list, _faiss_rules_index
    if not _initialized:
        initialize_memory_system()
    _rules_items_list = []
    if _faiss_rules_index:
        _faiss_rules_index.reset()
    try:
        if STORAGE_BACKEND == "SQLITE":
            with _get_sqlite_connection() as conn:
                conn.execute("DELETE FROM rules")
                conn.commit()
        elif STORAGE_BACKEND == "HF_DATASET":
            Dataset.from_dict({"rule_text": []}).push_to_hub(HF_RULES_DATASET_REPO, token=HF_TOKEN, private=True)
        logger.info("All rules cleared.")
        return True
    except Exception as e:
        logger.error(f"Error clearing all rules data: {e}")
        return False

FAISS_MEMORY_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "memory_index.faiss")
FAISS_RULES_PATH = os.path.join(os.getenv("FAISS_STORAGE_PATH", "app_data/faiss_indices"), "rules_index.faiss")

def save_faiss_indices_to_disk():
    if not _initialized or not faiss:
        return
    faiss_dir = os.path.dirname(FAISS_MEMORY_PATH)
    if not os.path.exists(faiss_dir):
        os.makedirs(faiss_dir, exist_ok=True)
    if _faiss_memory_index and _faiss_memory_index.ntotal > 0:
        faiss.write_index(_faiss_memory_index, FAISS_MEMORY_PATH)
    if _faiss_rules_index and _faiss_rules_index.ntotal > 0:
        faiss.write_index(_faiss_rules_index, FAISS_RULES_PATH)

def load_faiss_indices_from_disk():
    global _faiss_memory_index, _faiss_rules_index
    if not _initialized or not faiss:
        return
    if os.path.exists(FAISS_MEMORY_PATH):
        _faiss_memory_index = faiss.read_index(FAISS_MEMORY_PATH)
    if os.path.exists(FAISS_RULES_PATH):
        _faiss_rules_index = faiss.read_index(FAISS_RULES_PATH)
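
# --- Illustrative smoke test (a minimal sketch, not part of the module's API) ---
# Shows the intended call sequence, assuming STORAGE_BACKEND=RAM is exported and
# the optional RAG dependencies (sentence-transformers, faiss, numpy) are
# installed. The rule and memory texts below are hypothetical examples.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    initialize_memory_system()
    print(add_rule_entry("[CORE_RULE|1.0] Prefer concise, sourced answers."))
    print(add_memory_entry(
        user_input="What is FAISS?",
        metrics={"takeaway": "FAISS performs efficient vector similarity search."},
        bot_response="FAISS is a library for efficient similarity search over dense vectors.",
    ))
    print(retrieve_rules_semantic("how should answers be phrased?", k=1))
    print(retrieve_memories_semantic("vector search library", k=1))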