aamirhameed committed on
Commit
719919b
·
verified ·
1 Parent(s): 4b47a9f

Update knowledge_engine.py

Browse files
Files changed (1) hide show
  1. knowledge_engine.py +24 -18
knowledge_engine.py CHANGED
@@ -6,26 +6,40 @@ from concurrent.futures import ThreadPoolExecutor
6
 
7
  from config import Config
8
 
 
 
 
9
  # Core ML/AI libraries
10
  from langchain_community.document_loaders import TextLoader, DirectoryLoader
11
  from langchain.text_splitter import RecursiveCharacterTextSplitter
12
  from langchain_community.vectorstores import FAISS
13
- from langchain_community.embeddings import OllamaEmbeddings
14
  from langchain.chains import RetrievalQA
15
  from langchain.prompts import PromptTemplate
16
- from langchain_community.llms import Ollama
17
- from langchain_community.retrievers import BM25Retriever
18
 
 
 
 
19
 
20
- class KnowledgeManager:
21
- """Main knowledge management class handling document processing and Q&A with CoT & MoE routing"""
22
 
 
23
  def __init__(self):
24
  Config.setup_dirs()
25
- self.embeddings = OllamaEmbeddings(model="mxbai-embed-large")
26
  self.vector_db, self.bm25_retriever = self._init_retrievers()
27
  self.qa_chain = self._create_moe_qa_chain()
28
 
 
 
 
 
 
 
 
 
 
 
 
29
  def _init_retrievers(self):
30
  faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
31
  faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
@@ -42,7 +56,7 @@ class KnowledgeManager:
42
  bm25_retriever = pickle.load(f)
43
  return vector_db, bm25_retriever
44
  except Exception as e:
45
- print(f"[!] Error loading existing vector store: {e}. Rebuilding...")
46
 
47
  return self._build_retrievers_from_documents()
48
 
@@ -77,18 +91,15 @@ class KnowledgeManager:
77
  return vector_db, bm25_retriever
78
 
79
  def _create_default_knowledge(self):
80
- default_text = """Sirraya xBrain - Advanced AI Platform\n\nCreated by Amir Hameed.\n\nFeatures:\n- Hybrid Retrieval (Vector + BM25)\n- LISA Assistant\n- FAISS, Ollama, BM25 Integration"""
81
  with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
82
  f.write(default_text)
83
 
84
  def _parallel_retrieve(self, question: str):
85
- """Parallel retrieval execution: simulates Mixture of Experts routing"""
86
-
87
  def retrieve_with_bm25():
88
  return self.bm25_retriever.get_relevant_documents(question)
89
 
90
  def retrieve_with_vector():
91
- # Lowered threshold to 0.3 for better doc retrieval (adjust as needed)
92
  retriever = self.vector_db.as_retriever(
93
  search_type="similarity_score_threshold",
94
  search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS, "score_threshold": 0.83}
@@ -101,7 +112,6 @@ class KnowledgeManager:
101
  bm25_results = bm25_future.result()
102
  vector_results = vector_future.result()
103
 
104
- # Combine results; duplicates are possible, consider deduplication if needed
105
  return vector_results + bm25_results
106
 
107
  def _create_moe_qa_chain(self):
@@ -123,9 +133,9 @@ Instructions:
123
  Answer:"""
124
 
125
  return RetrievalQA.from_chain_type(
126
- llm=Ollama(model="phi", temperature=0.1),
127
  chain_type="stuff",
128
- retriever=self.vector_db.as_retriever(search_kwargs={"k": 1}), # Dummy retriever to satisfy LangChain
129
  chain_type_kwargs={
130
  "prompt": PromptTemplate(
131
  template=prompt_template,
@@ -136,7 +146,6 @@ Answer:"""
136
  )
137
 
138
  def query(self, question: str) -> Dict[str, Any]:
139
- """Query system using CoT + MoE logic"""
140
  if not self.qa_chain:
141
  return {
142
  "answer": "Knowledge system not initialized. Please reload.",
@@ -148,14 +157,11 @@ Answer:"""
148
  start_time = datetime.now()
149
  docs = self._parallel_retrieve(question)
150
 
151
- # If no docs found, fallback to retriever without threshold for testing
152
  if not docs:
153
  retriever = self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS})
154
  docs = retriever.get_relevant_documents(question)
155
 
156
- # Use invoke() for chains with multiple outputs
157
  result = self.qa_chain.invoke({"input_documents": docs, "query": question})
158
-
159
  processing_time = (datetime.now() - start_time).total_seconds() * 1000
160
 
161
  return {
 
6
 
7
  from config import Config
8
 
9
+ # Set up the Hugging Face token securely (make sure it is set in your environment)
10
+ # os.environ["HUGGINGFACEHUB_API_TOKEN"] = "your_token_here"
11
+
12
  # Core ML/AI libraries
13
  from langchain_community.document_loaders import TextLoader, DirectoryLoader
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
15
  from langchain_community.vectorstores import FAISS
 
16
  from langchain.chains import RetrievalQA
17
  from langchain.prompts import PromptTemplate
18
+ from langchain.retrievers import BM25Retriever
 
19
 
20
+ # Only use Hugging Face embeddings and LLM (no Ollama fallback)
21
+ from langchain_community.embeddings import HuggingFaceEmbeddings
22
+ from langchain_community.llms import HuggingFaceHub
23
 
 
 
24
 
25
+ class KnowledgeManager:
26
  def __init__(self):
27
  Config.setup_dirs()
28
+ self.embeddings = self._init_embeddings()
29
  self.vector_db, self.bm25_retriever = self._init_retrievers()
30
  self.qa_chain = self._create_moe_qa_chain()
31
 
32
+ def _init_embeddings(self):
33
+ print("[i] Using Hugging Face embeddings")
34
+ return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
35
+
36
+ def _init_llm(self):
37
+ print("[i] Using Hugging Face LLM")
38
+ return HuggingFaceHub(
39
+ repo_id="tiiuae/falcon-7b-instruct",
40
+ model_kwargs={"temperature": 0.1, "max_new_tokens": 512}
41
+ )
42
+
43
  def _init_retrievers(self):
44
  faiss_index_path = Config.VECTOR_STORE_PATH / "index.faiss"
45
  faiss_pkl_path = Config.VECTOR_STORE_PATH / "index.pkl"
 
56
  bm25_retriever = pickle.load(f)
57
  return vector_db, bm25_retriever
58
  except Exception as e:
59
+ print(f"[!] Error loading vector store: {e}. Rebuilding...")
60
 
61
  return self._build_retrievers_from_documents()
62
 
 
91
  return vector_db, bm25_retriever
92
 
93
  def _create_default_knowledge(self):
94
+ default_text = """Sirraya xBrain - Advanced AI Platform\n\nCreated by Amir Hameed.\n\nFeatures:\n- Hybrid Retrieval (Vector + BM25)\n- LISA Assistant\n- FAISS, BM25 Integration"""
95
  with open(Config.KNOWLEDGE_DIR / "sirraya_xbrain.txt", "w", encoding="utf-8") as f:
96
  f.write(default_text)
97
 
98
  def _parallel_retrieve(self, question: str):
 
 
99
  def retrieve_with_bm25():
100
  return self.bm25_retriever.get_relevant_documents(question)
101
 
102
  def retrieve_with_vector():
 
103
  retriever = self.vector_db.as_retriever(
104
  search_type="similarity_score_threshold",
105
  search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS, "score_threshold": 0.83}
 
112
  bm25_results = bm25_future.result()
113
  vector_results = vector_future.result()
114
 
 
115
  return vector_results + bm25_results
116
 
117
  def _create_moe_qa_chain(self):
 
133
  Answer:"""
134
 
135
  return RetrievalQA.from_chain_type(
136
+ llm=self._init_llm(),
137
  chain_type="stuff",
138
+ retriever=self.vector_db.as_retriever(search_kwargs={"k": 1}),
139
  chain_type_kwargs={
140
  "prompt": PromptTemplate(
141
  template=prompt_template,
 
146
  )
147
 
148
  def query(self, question: str) -> Dict[str, Any]:
 
149
  if not self.qa_chain:
150
  return {
151
  "answer": "Knowledge system not initialized. Please reload.",
 
157
  start_time = datetime.now()
158
  docs = self._parallel_retrieve(question)
159
 
 
160
  if not docs:
161
  retriever = self.vector_db.as_retriever(search_kwargs={"k": Config.MAX_CONTEXT_CHUNKS})
162
  docs = retriever.get_relevant_documents(question)
163
 
 
164
  result = self.qa_chain.invoke({"input_documents": docs, "query": question})
 
165
  processing_time = (datetime.now() - start_time).total_seconds() * 1000
166
 
167
  return {