aamirhameed committed
Commit 5d56f39 · verified · 1 Parent(s): 48a324e

Update knowledge_engine.py

Files changed (1)
  1. knowledge_engine.py +99 -302
knowledge_engine.py CHANGED
@@ -1,315 +1,112 @@
  import os
- import tempfile
- import shutil
- from typing import Dict, List
- from datetime import datetime
- from concurrent.futures import ThreadPoolExecutor
-
- from langchain_core.documents import Document
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import FAISS
- from langchain.retrievers import BM25Retriever
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, pipeline
  from langchain.llms import HuggingFacePipeline

- class CPULLMProvider:
-     """CPU-based LLM provider using HuggingFace models"""
-
-     def __init__(self):
-         self.name = "CPU-LLM"
-         self.is_available = False
-         self.current_model = None
-
-         # CPU-friendly models
-         self.cpu_models = [
-             "google/flan-t5-small",  # Encoder-decoder model
-             "distilbert/distilgpt2"  # Decoder-only (GPT-style)
-         ]
-
-     def initialize(self) -> bool:
-         """Initialize the CPU LLM with the best available model"""
-         for model_id in self.cpu_models:
-             try:
-                 print(f"[i] Trying to load {model_id}...")
-
-                 tokenizer = AutoTokenizer.from_pretrained(model_id)
-
-                 # Detect model type based on name
-                 if "flan" in model_id or "t5" in model_id:
-                     model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-                     task = "text2text-generation"
-                 else:
-                     model = AutoModelForCausalLM.from_pretrained(model_id)
-                     task = "text-generation"
-
-                 pipe = pipeline(
-                     task,
-                     model=model,
-                     tokenizer=tokenizer,
-                     max_new_tokens=256,
-                     temperature=0.3,
-                     top_p=0.95,
-                     device="cpu"
-                 )
-
-                 self.llm = HuggingFacePipeline(pipeline=pipe)
-                 self.current_model = model_id
-                 self.is_available = True
-
-                 # Test model
-                 test_response = self.invoke("Hello, who are you?")
-                 if test_response and len(test_response) > 0:
-                     print(f"[✓] Successfully loaded {model_id}")
-                     return True
-
-             except Exception as e:
-                 print(f"[!] Failed to load {model_id}: {str(e)[:200]}...")
-                 continue
-
-         print("[!] All CPU models failed to load")
-         return False
-
-     def invoke(self, prompt: str) -> str:
-         """Invoke the CPU model with prompt"""
-         if not self.llm:
-             raise Exception("CPU LLM not initialized")
-
-         try:
-             # Optionally modify prompt for specific models if needed
-             formatted_prompt = prompt
-             response = self.llm.invoke(formatted_prompt)
-             return response.strip()
-         except Exception as e:
-             print(f"[!] CPU model error: {e}")
-             raise
-
-

  class KnowledgeManager:
93
- def __init__(self):
94
- self.temp_dir = tempfile.mkdtemp()
95
- self.setup_dirs()
96
- self.embeddings = self.init_embeddings()
97
- self.vector_db = None
98
- self.bm25_retriever = None
99
- self.llm_provider = CPULLMProvider()
100
- self.knowledge_texts = []
101
-
102
- self.init_system()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- def setup_dirs(self):
105
- """Setup temporary directories"""
106
- self.knowledge_dir = os.path.join(self.temp_dir, "knowledge")
107
- os.makedirs(self.knowledge_dir, exist_ok=True)
108
 
109
- def init_embeddings(self):
110
- """Initialize CPU-friendly embeddings"""
111
  try:
112
- return HuggingFaceEmbeddings(
113
- model_name="sentence-transformers/all-MiniLM-L6-v2",
114
- model_kwargs={'device': 'cpu'},
115
- encode_kwargs={'normalize_embeddings': True}
 
 
116
  )
 
117
  except Exception as e:
118
- print(f"[!] Failed to load embeddings: {e}")
119
- return None
120
-
121
- def init_system(self):
122
- """Initialize the RAG system"""
123
- print("[i] Initializing CPU LLM...")
124
- if self.llm_provider.initialize():
125
- print(f"[✓] Using model: {self.llm_provider.current_model}")
126
- else:
127
- print("[!] Continuing without LLM (retrieval only)")
128
-
129
- # Load default knowledge
130
- self._load_default_knowledge()
131
-
132
- # Build retrievers
133
- self.build_retrievers()
134
-
135
- def _load_default_knowledge(self):
136
- """Load default knowledge base"""
137
- default_content = """Sirraya xBrain - CPU-based AI Platform
138
-
139
- Features:
140
- - Uses efficient CPU-based language models like Phi-2
141
- - Implements RAG (Retrieval-Augmented Generation)
142
- - Combines vector search and keyword retrieval
143
- - Optimized for CPU-only environments
144
-
145
- Technical Details:
146
- - Embeddings: all-MiniLM-L6-v2
147
- - Vector Store: FAISS
148
- - Keyword Retrieval: BM25
149
- - LLM: Microsoft Phi-2 or similar CPU-friendly models"""
150
-
151
- self.knowledge_texts = [{
152
- "filename": "default_knowledge.txt",
153
- "content": default_content
154
- }]
155
-
156
- # Save to file
157
- with open(os.path.join(self.knowledge_dir, "default_knowledge.txt"), "w") as f:
158
- f.write(default_content)
159
-
160
- def build_retrievers(self):
161
- """Build the retrieval components"""
162
- if not self.embeddings:
163
- print("[!] No embeddings available")
164
- return
165
-
166
- try:
167
- # Create documents
168
- documents = [
169
- Document(
170
- page_content=text["content"],
171
- metadata={"source": text["filename"]}
172
- )
173
- for text in self.knowledge_texts
174
- ]
175
-
176
- # Split documents
177
- splitter = RecursiveCharacterTextSplitter(
178
- chunk_size=512,
179
- chunk_overlap=128,
180
- separators=["\n\n", "\n", ". ", "! ", "? ", "; ", " ", ""]
181
  )
182
- chunks = splitter.split_documents(documents)
183
-
184
- # Create vector store
185
- self.vector_db = FAISS.from_documents(
186
- chunks,
187
- self.embeddings,
188
- distance_strategy="COSINE"
189
- )
190
-
191
- # Create BM25 retriever
192
- self.bm25_retriever = BM25Retriever.from_documents(chunks)
193
- self.bm25_retriever.k = 3
194
-
195
- print(f"[✓] Built retrievers with {len(chunks)} chunks")
196
-
197
- except Exception as e:
198
- print(f"[!] Error building retrievers: {e}")
199
-
200
- def retrieve_documents(self, query: str) -> List[Document]:
201
- """Retrieve relevant documents using both methods"""
202
- if not self.vector_db or not self.bm25_retriever:
203
- return []
204
-
205
- def vector_search():
206
- try:
207
- return self.vector_db.similarity_search(query, k=2)
208
- except:
209
- return []
210
-
211
- def bm25_search():
212
- try:
213
- return self.bm25_retriever.invoke(query)
214
- except:
215
- return []
216
-
217
- with ThreadPoolExecutor(max_workers=2) as executor:
218
- vector_future = executor.submit(vector_search)
219
- bm25_future = executor.submit(bm25_search)
220
- vector_results = vector_future.result()
221
- bm25_results = bm25_future.result()
222
-
223
- # Combine and deduplicate
224
- combined = vector_results + bm25_results
225
- unique_docs = []
226
- seen = set()
227
-
228
- for doc in combined:
229
- content_hash = hash(doc.page_content)
230
- if content_hash not in seen:
231
- seen.add(content_hash)
232
- unique_docs.append(doc)
233
-
234
- return unique_docs[:3] # Return top 3 unique docs
235
-
236
- def query(self, query: str) -> Dict[str, any]:
237
- """Process a query with RAG"""
238
- start_time = datetime.now()
239
-
240
- # Retrieve relevant documents
241
- docs = self.retrieve_documents(query)
242
-
243
- if not docs:
244
- return {
245
- "answer": "No relevant information found.",
246
- "sources": [],
247
- "model": "none",
248
- "time_ms": 0
249
- }
250
-
251
- # Prepare context
252
- context = "\n\n".join([doc.page_content for doc in docs])
253
-
254
- # Generate answer if LLM is available
255
- if self.llm_provider.is_available:
256
- try:
257
- prompt = f"""Use the following context to answer the question:
258
-
259
- Context:
260
- {context}
261
-
262
- Question: {query}
263
-
264
- Answer:"""
265
-
266
- answer = self.llm_provider.invoke(prompt)
267
-
268
- return {
269
- "answer": answer,
270
- "sources": [doc.metadata.get("source", "") for doc in docs],
271
- "model": self.llm_provider.current_model,
272
- "time_ms": (datetime.now() - start_time).total_seconds() * 1000
273
- }
274
- except Exception as e:
275
- print(f"[!] LLM error: {e}")
276
- # Fall through to retrieval mode
277
-
278
- # Fallback: return best matching document
279
- best_doc = docs[0].page_content[:500] + "..." if len(docs[0].page_content) > 500 else docs[0].page_content
280
- return {
281
- "answer": f"Relevant information:\n\n{best_doc}",
282
- "sources": [doc.metadata.get("source", "") for doc in docs],
283
- "model": "retrieval-only",
284
- "time_ms": (datetime.now() - start_time).total_seconds() * 1000
285
- }
286
-
287
- def add_document(self, filename: str, content: str) -> bool:
288
- """Add a document to the knowledge base"""
289
- try:
290
- self.knowledge_texts.append({
291
- "filename": filename,
292
- "content": content
293
- })
294
-
295
- # Save to file
296
- with open(os.path.join(self.knowledge_dir, filename), "w") as f:
297
- f.write(content)
298
-
299
- # Rebuild retrievers
300
- self.build_retrievers()
301
- return True
302
-
303
- except Exception as e:
304
- print(f"[!] Error adding document: {e}")
305
- return False
306
-
307
- def cleanup(self):
308
- """Clean up temporary files"""
309
- try:
310
- shutil.rmtree(self.temp_dir)
311
- except:
312
- pass
313
-
314
- def __del__(self):
315
- self.cleanup()
 
  import os
+ from pathlib import Path
+ from typing import List, Optional

+ import faiss
+ import numpy as np
+ from sentence_transformers import SentenceTransformer

  from langchain.llms import HuggingFacePipeline
+ from langchain.chains import RetrievalQA
+ from langchain.vectorstores.faiss import FAISS
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.document_loaders import TextLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter

+ import torch
+ from transformers import pipeline

  class KnowledgeManager:
+     def __init__(self, knowledge_dir="knowledge_base"):
+         self.knowledge_dir = Path(knowledge_dir)
+         self.knowledge_dir.mkdir(exist_ok=True, parents=True)
+
+         self.documents = []
+         self.texts = []
+         self.vectorstore = None
+         self.retriever = None
+         self.qa_chain = None
+         self.llm = None
+
+         self.device = "cpu"  # For HF Spaces, CPU only
+
+         # Initialize embeddings
+         self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+         # Load and prepare knowledge
+         self.load_documents()
+         self.create_vectorstore()
+         self.init_llm()
+         self.init_qa_chain()
+
+     def load_documents(self):
+         # Load text files and split into chunks
+         files = list(self.knowledge_dir.glob("*.txt"))
+         self.documents = []
+         for file in files:
+             loader = TextLoader(str(file), encoding="utf-8")
+             docs = loader.load()
+             self.documents.extend(docs)
+
+         # Split into smaller chunks (to improve retrieval granularity)
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+         self.texts = text_splitter.split_documents(self.documents)
+
+     def create_vectorstore(self):
+         if not self.texts:
+             self.vectorstore = None
+             return
+         self.vectorstore = FAISS.from_documents(self.texts, self.embeddings)
+         self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})

+     def init_llm(self):
+         # Initialize HuggingFace pipeline + LangChain wrapper LLM

+         # Try flan-t5-small first
          try:
+             pipe = pipeline(
+                 "text2text-generation",
+                 model="google/flan-t5-small",
+                 device=-1,  # CPU only
+                 max_length=256,
+                 do_sample=False,
              )
+             self.llm = HuggingFacePipeline(pipeline=pipe)
          except Exception as e:
+             print(f"Failed to load flan-t5-small: {e}")
+             self.llm = None
+
+         # Fallback: if no LLM, set to None and warn
+         if self.llm is None:
+             print("No LLM available, will fallback to retrieval-only.")
+
+     def init_qa_chain(self):
+         if self.llm and self.retriever:
+             self.qa_chain = RetrievalQA.from_chain_type(
+                 llm=self.llm,
+                 retriever=self.retriever,
+                 return_source_documents=True,
+                 chain_type="stuff",  # Stuff all docs in prompt, or "map_reduce"
              )
+         else:
+             self.qa_chain = None
+
+     def get_knowledge_summary(self) -> str:
+         count = len(self.texts) if self.texts else 0
+         return f"{count} document chunks loaded."
+
+     def query(self, question: str):
+         if self.qa_chain:
+             # Use LLM + retrieval
+             result = self.qa_chain({"query": question})
+             answer = result.get("result", "No answer found.")
+             sources = result.get("source_documents", [])
+             source_texts = [doc.page_content for doc in sources]
+             return answer, source_texts
+         elif self.retriever:
+             # Retrieval only fallback
+             docs = self.retriever.get_relevant_documents(question)
+             answers = [doc.page_content for doc in docs]
+             return "\n\n".join(answers), []
+         else:
+             return "Knowledge base not initialized.", []
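
A minimal usage sketch of the class as it stands after this commit (illustrative only, not part of the repository; it assumes a knowledge_base/ folder of .txt files sitting next to the script and a sample question):

    # Hypothetical caller for the KnowledgeManager shown above.
    # Assumes knowledge_base/*.txt exists; paths and the question are made up.
    from knowledge_engine import KnowledgeManager

    km = KnowledgeManager(knowledge_dir="knowledge_base")
    print(km.get_knowledge_summary())  # e.g. "12 document chunks loaded."

    # query() returns a (answer, source_texts) tuple; if no LLM could be
    # loaded it falls back to returning the raw retrieved chunks instead.
    answer, sources = km.query("Which embedding model does the knowledge base use?")
    print(answer)
    for text in sources:
        print("---", text[:80])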