Update knowledge_engine.py
knowledge_engine.py   +44 -46   CHANGED
@@ -1,61 +1,59 @@
 import os
-from pathlib import Path
-from langchain.document_loaders import TextLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
-from langchain.embeddings import
+from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.chains import RetrievalQA
-from langchain.llms import
+from langchain.llms import HuggingFacePipeline
+from transformers import pipeline

 class KnowledgeManager:
-    def __init__(self,
-        self.
-        self.
-        self.embeddings = None
-        self.vectorstore = None
-        self.retriever = None
-        self.llm = None
+    def __init__(self, root_dir="."):
+        self.root_dir = root_dir
+        self.docsearch = None
         self.qa_chain = None
+        self.llm = None

-        self.
-
-
-
-
-
-
-    def _load_documents(self):
-        if not self.knowledge_dir.exists():
-            raise FileNotFoundError(f"Directory {self.knowledge_dir} does not exist.")
-
-        files = list(self.knowledge_dir.glob("*.txt"))
-        if not files:
-            raise FileNotFoundError(f"No .txt files found in {self.knowledge_dir}. Please upload your knowledge base files in root.")
-
-        for file in files:
-            loader = TextLoader(str(file))
-            self.documents.extend(loader.load())
-
-        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
-        self.documents = splitter.split_documents(self.documents)
+        self._initialize_llm()
+        self._initialize_embeddings()
+        self._load_knowledge_base()
+
+    def _initialize_llm(self):
+        local_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=256)
+        self.llm = HuggingFacePipeline(pipeline=local_pipe)

     def _initialize_embeddings(self):
-        self.embeddings =
+        self.embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")

-    def
-
-
+    def _load_knowledge_base(self):
+        # Find all .txt files in root directory
+        txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]

-
-
+        if not txt_files:
+            raise FileNotFoundError("No .txt files found in root directory.")

-
-
+        all_texts = []
+        for filename in txt_files:
+            path = os.path.join(self.root_dir, filename)
+            with open(path, "r", encoding="utf-8") as f:
+                content = f.read()
+            all_texts.append(content)
+
+        full_text = "\n\n".join(all_texts)
+
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+        docs = text_splitter.create_documents([full_text])
+
+        self.docsearch = FAISS.from_documents(docs, self.embeddings)
+
+        self.qa_chain = RetrievalQA.from_chain_type(
+            llm=self.llm,
+            chain_type="stuff",
+            retriever=self.docsearch.as_retriever(),
+            return_source_documents=True,
+        )

     def ask(self, query):
         if not self.qa_chain:
-
-
-
-    def get_knowledge_summary(self):
-        return f"Loaded {len(self.documents)} document chunks from {self.knowledge_dir}"
+            raise ValueError("Knowledge base not initialized.")
+        result = self.qa_chain(query)
+        return result['result']
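Since the rebuilt chain passes return_source_documents=True, every call already computes the retrieved chunks, but ask() returns only result['result'] and drops them. The sketch below shows one way to surface them; it is not part of this commit, and the helper name ask_with_sources is an assumption.

# Hypothetical helper, not in the commit: exposes the source chunks that
# RetrievalQA already returns because return_source_documents=True is set.
def ask_with_sources(manager, query):
    if not manager.qa_chain:
        raise ValueError("Knowledge base not initialized.")
    result = manager.qa_chain(query)  # dict with "result" and "source_documents"
    sources = [doc.page_content for doc in result["source_documents"]]
    return result["result"], sources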
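For reference, a minimal usage sketch of the updated class, assuming the module is importable as knowledge_engine and the knowledge .txt files sit in the working directory (the default root_dir="." that the new loader scans); the sample question is illustrative only.

# Hypothetical usage, not in the commit. Constructing the object loads the
# flan-t5-small pipeline and the instructor embeddings, then builds the
# FAISS index from the .txt files found in root_dir.
from knowledge_engine import KnowledgeManager

km = KnowledgeManager(root_dir=".")
answer = km.ask("What topics does the knowledge base cover?")
print(answer)  # ask() returns only the "result" string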
|