aamirhameed committed on
Commit
1aba791
·
verified ·
1 Parent(s): e08ac3a

Update knowledge_engine.py

Browse files
Files changed (1) hide show
  1. knowledge_engine.py +44 -46
knowledge_engine.py CHANGED
@@ -1,61 +1,59 @@
1
  import os
2
- from pathlib import Path
3
- from langchain.document_loaders import TextLoader
4
- from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain.vectorstores import FAISS
6
- from langchain.embeddings import HuggingFaceEmbeddings
7
  from langchain.chains import RetrievalQA
8
- from langchain.llms import HuggingFaceHub
 
9
 
10
class KnowledgeManager:
    """Loads .txt knowledge files, indexes them in FAISS, and answers queries."""

    def __init__(self, knowledge_dir="."):
        # Root directory by default; all lookups go through pathlib.
        self.knowledge_dir = Path(knowledge_dir)
        self.documents = []
        self.embeddings = None
        self.vectorstore = None
        self.retriever = None
        self.llm = None
        self.qa_chain = None

        self._load_documents()
        # Skip the model/index setup entirely when splitting produced no chunks.
        if self.documents:
            self._initialize_embeddings()
            self._initialize_vectorstore()
            self._initialize_llm()
            self._initialize_qa_chain()

    def _load_documents(self):
        """Read every .txt file under knowledge_dir and split it into chunks."""
        if not self.knowledge_dir.exists():
            raise FileNotFoundError(f"Directory {self.knowledge_dir} does not exist.")

        txt_paths = list(self.knowledge_dir.glob("*.txt"))
        if not txt_paths:
            raise FileNotFoundError(f"No .txt files found in {self.knowledge_dir}. Please upload your knowledge base files in root.")

        for txt_path in txt_paths:
            self.documents.extend(TextLoader(str(txt_path)).load())

        chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        self.documents = chunker.split_documents(self.documents)

    def _initialize_embeddings(self):
        """Create the sentence-transformer embedding model."""
        self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    def _initialize_vectorstore(self):
        """Index the document chunks in FAISS and expose a retriever over them."""
        self.vectorstore = FAISS.from_documents(self.documents, self.embeddings)
        self.retriever = self.vectorstore.as_retriever()

    def _initialize_llm(self):
        """Point the chain at a small hosted FLAN-T5 model on the HuggingFace Hub."""
        self.llm = HuggingFaceHub(repo_id="google/flan-t5-small", model_kwargs={"temperature":0, "max_length":256})

    def _initialize_qa_chain(self):
        """Assemble the retrieval-augmented QA chain over the retriever."""
        self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff", retriever=self.retriever)

    def ask(self, query):
        """Run the query through the QA chain, or report an uninitialized state."""
        if self.qa_chain:
            return self.qa_chain.run(query)
        return "Knowledge base not initialized properly."

    def get_knowledge_summary(self):
        """Describe how many chunks were loaded and from where."""
        return f"Loaded {len(self.documents)} document chunks from {self.knowledge_dir}"
 
1
  import os
 
 
 
2
  from langchain.vectorstores import FAISS
3
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
4
  from langchain.chains import RetrievalQA
5
+ from langchain.llms import HuggingFacePipeline
6
+ from transformers import pipeline
7
 
8
class KnowledgeManager:
    """Builds a FAISS-backed RetrievalQA chain over .txt files found in a directory."""

    def __init__(self, root_dir="."):
        """Initialize the local LLM, the embeddings, and the knowledge base.

        Args:
            root_dir: Directory scanned (non-recursively) for .txt files.
        """
        self.root_dir = root_dir
        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        # Declared up front like the other attributes so a failure inside
        # _initialize_embeddings() cannot leave the attribute missing entirely.
        self.embeddings = None

        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        """Create a local text2text pipeline and wrap it for LangChain."""
        # Running the pipeline locally avoids needing a HuggingFace Hub API token.
        local_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=256)
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self):
        """Load the instructor embedding model used for vector search."""
        self.embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")

    def _load_knowledge_base(self):
        """Read all .txt files in root_dir, chunk them, and build the QA chain.

        Raises:
            FileNotFoundError: If no .txt files exist in root_dir.
        """
        # isfile(): a *directory* named "something.txt" would otherwise pass the
        # suffix filter and crash open() with IsADirectoryError.
        # sorted(): os.listdir order is platform-dependent; sorting makes the
        # chunk order (and therefore the built index) deterministic.
        txt_files = sorted(
            f for f in os.listdir(self.root_dir)
            if f.endswith(".txt") and os.path.isfile(os.path.join(self.root_dir, f))
        )

        if not txt_files:
            raise FileNotFoundError("No .txt files found in root directory.")

        all_texts = []
        for filename in txt_files:
            path = os.path.join(self.root_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                all_texts.append(f.read())

        full_text = "\n\n".join(all_texts)

        # Imported lazily, matching the original: keeps module import light.
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        docs = text_splitter.create_documents([full_text])

        self.docsearch = FAISS.from_documents(docs, self.embeddings)

        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        """Answer a natural-language question against the knowledge base.

        Args:
            query: The question to answer.

        Returns:
            The generated answer string (the chain's 'result' field).

        Raises:
            ValueError: If the QA chain was never built.
        """
        if not self.qa_chain:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result['result']