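"""Searchable FAISS index for chatbot documents.

Builds, merges, and queries a FAISS vector store from .txt, .pdf,
Excel, and .csv sources using LangChain document loaders and
OpenAIEmbeddings.
"""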
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
PyPDFLoader,
DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import queue


class SearchableIndex:
    def __init__(self, path):
        self.path = path

    def get_text_splits(self):
        with open(self.path, 'r') as txt:
            data = txt.read()
        text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                    chunk_overlap=0,
                                                    length_function=len)
        doc_list = text_split.split_text(data)
        return doc_list

    def get_pdf_splits(self):
        loader = PyPDFLoader(self.path)
        pages = loader.load_and_split()
        text_split = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                    chunk_overlap=0,
                                                    length_function=len)
        doc_list = []
        for pg in pages:
            pg_splits = text_split.split_text(pg.page_content)
            doc_list.extend(pg_splits)
        return doc_list

    def get_xml_splits(self, target_col, sheet_name):
        # Note: despite the name, this loads an Excel workbook via openpyxl.
        df = pd.read_excel(io=self.path,
                           engine='openpyxl',
                           sheet_name=sheet_name)
        df_loader = DataFrameLoader(df,
                                    page_content_column=target_col)
        excel_docs = df_loader.load()
        return excel_docs

    def get_csv_splits(self):
        csv_loader = CSVLoader(self.path)
        csv_docs = csv_loader.load()
        return csv_docs

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        if os.path.exists(index_store):
            # Merge the freshly built vectors into the existing on-disk index.
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            logger.info("Merge index completed")
            local_db.save_local(index_store)
            return local_db
        # No index yet: persist the new one, then reload it from disk.
        faiss_db.save_local(folder_path=index_store)
        logger.info("New store created and loaded...")
        local_db = FAISS.load_local(index_store, embeddings)
        return local_db

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, path, result_queue):
        local_db = None
        if index_files:
            local_db = FAISS.load_local(index_files[0], embeddings)
            # The combined source file is no longer needed once the index loads.
            file_to_remove = os.path.join(path, 'combined_content.txt')
            if os.path.exists(file_to_remove):
                os.remove(file_to_remove)
        else:
            logger.warning("Index store does not exist")
        result_queue.put(local_db)  # Put the result (or None) in the queue

    @classmethod
    def embed_index(cls, url, path, target_col=None, sheet_name=None):
        embeddings = OpenAIEmbeddings()

        def process_docs(queues, extension):
            # Dispatch to the loader that matches the file extension.
            instance = cls(path)
            if extension == ".txt":
                docs = instance.get_text_splits()
            elif extension == ".pdf":
                docs = instance.get_pdf_splits()
            elif extension == ".xml":
                docs = instance.get_xml_splits(target_col, sheet_name)
            elif extension == ".csv":
                docs = instance.get_csv_splits()
            else:
                docs = None  # Unsupported extension
            queues.put(docs)

        if url != 'NO_URL' and path:
            # Build a new index (or merge into an existing one) from the file.
            file_extension = os.path.splitext(path)[1].lower()
            data_queue = queue.Queue()
            thread = threading.Thread(target=process_docs,
                                      args=(data_queue, file_extension))
            thread.start()
            doc_list = data_queue.get()
            if not doc_list:
                raise ValueError("Unsupported file format")
            faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return local_db, index_store
        elif url == 'NO_URL' and path:
            # Load a previously saved index from a directory of '*_index' stores.
            index_files = glob.glob(os.path.join(path, '*_index'))
            result_queue = queue.Queue()  # Queue to receive the loaded index
            thread = threading.Thread(target=cls.check_and_load_index,
                                      args=(index_files, embeddings, logger, path, result_queue))
            thread.start()
            local_db = result_queue.get()  # Retrieve the result from the queue
            return local_db

    @classmethod
    def query(cls, question: str, llm, index):
        """Query the vectorstore."""
        llm = llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=index.as_retriever()
        )
        return chain.run(question)


if __name__ == '__main__':
    pass
    # Example search query (embed_index also requires a url argument; any
    # value other than 'NO_URL' builds the index and returns it together
    # with its store path):
    # index, index_store = SearchableIndex.embed_index(
    #     url='<document-url>',
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt")
    # prompt = 'show more detail about types of data collected'
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # result = SearchableIndex.query(prompt, llm=llm, index=index)
    # print(result)
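    # Hypothetical sketch of the reload path: with url='NO_URL', `path` is
    # treated as a directory that already contains a '*_index' store (the
    # directory below is illustrative only, reusing the path from above):
    # index = SearchableIndex.embed_index(
    #     url='NO_URL',
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents")
    # result = SearchableIndex.query(prompt, llm=llm, index=index)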