"""Build, persist, load, and query FAISS vector indexes over local documents."""

import glob
import os
import queue  # noqa: F401  (kept: callers of check_and_load_index pass a queue)
import threading  # noqa: F401  (kept for backward compatibility of the module namespace)

import pandas as pd
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import (
    DataFrameLoader,
    PyPDFLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

from bot.utils.show_log import logger


class SearchableIndex:
    """Turn one document into chunks and manage FAISS indexes built from them.

    An instance wraps a single file path and exposes one loader per supported
    file type. The classmethods build, merge, load, and query FAISS stores.
    """

    def __init__(self, path):
        """Remember the path of the document this instance will load."""
        self.path = path

    @staticmethod
    def _make_splitter():
        """Return the character splitter shared by the text and PDF loaders."""
        return RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=0,
            length_function=len,
        )

    def get_text_splits(self):
        """Read ``self.path`` as text and return it as a list of string chunks."""
        with open(self.path, 'r') as txt:
            data = txt.read()
        return self._make_splitter().split_text(data)

    def get_pdf_splits(self):
        """Load the PDF at ``self.path`` and return a flat list of string chunks.

        Pages come from ``PyPDFLoader.load_and_split()``; each page's text is
        then split again with the shared character splitter.
        """
        pages = PyPDFLoader(self.path).load_and_split()
        text_split = self._make_splitter()
        doc_list = []
        for pg in pages:
            doc_list.extend(text_split.split_text(pg.page_content))
        return doc_list

    def get_xml_splits(self, target_col, sheet_name):
        """Load an Excel sheet and return one document per row.

        Despite the name, the file is read with ``pandas.read_excel`` using
        the openpyxl engine.

        Args:
            target_col: Column whose value becomes each document's content.
            sheet_name: Sheet to read. ``None`` selects the first sheet;
                passing ``None`` straight to pandas would return a dict of
                every sheet, which ``DataFrameLoader`` cannot consume.

        Returns:
            The list produced by ``DataFrameLoader.load()``.
        """
        if sheet_name is None:
            sheet_name = 0
        df = pd.read_excel(io=self.path, engine='openpyxl', sheet_name=sheet_name)
        df_loader = DataFrameLoader(df, page_content_column=target_col)
        return df_loader.load()

    def get_csv_splits(self):
        """Load the CSV at ``self.path`` and return ``CSVLoader.load()``'s documents."""
        return CSVLoader(self.path).load()

    def _load_splits(self, extension, target_col=None, sheet_name=None):
        """Dispatch to the loader for ``extension``; return ``None`` if unsupported."""
        if extension == ".txt":
            return self.get_text_splits()
        if extension == ".pdf":
            return self.get_pdf_splits()
        # ".xml" is the historical spelling and is kept for compatibility;
        # the loader reads Excel workbooks, so ".xlsx" is accepted as well.
        if extension in (".xml", ".xlsx"):
            return self.get_xml_splits(target_col, sheet_name)
        if extension == ".csv":
            return self.get_csv_splits()
        return None

    @staticmethod
    def _build_faiss(doc_list, embeddings):
        """Build a FAISS store from either plain strings or document objects.

        The text/PDF loaders return strings while the CSV/Excel loaders return
        document objects, so the matching FAISS constructor must be chosen.
        """
        if all(isinstance(item, str) for item in doc_list):
            return FAISS.from_texts(doc_list, embeddings)
        return FAISS.from_documents(doc_list, embeddings)

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        """Persist ``faiss_db`` at ``index_store``, merging into any existing store.

        Args:
            index_store: Folder path of the on-disk FAISS store.
            faiss_db: Freshly built in-memory FAISS store to persist.
            embeddings: Embeddings object used when loading a store from disk.
            logger: Logger used for progress messages.

        Returns:
            The FAISS store loaded from ``index_store`` after the save/merge.
        """
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            logger.info("Merge index completed")
            local_db.save_local(index_store)
            return local_db

        faiss_db.save_local(folder_path=index_store)
        logger.info("New store created and loaded...")
        return FAISS.load_local(index_store, embeddings)

    @classmethod
    def _load_existing_index(cls, index_files, embeddings, logger, path):
        """Load the first store in ``index_files`` and remove the combined text file.

        Raises:
            FileNotFoundError: If ``index_files`` is empty.
        """
        if not index_files:
            logger.warning("Index store does not exist")
            raise FileNotFoundError("Index store does not exist")

        local_db = FAISS.load_local(index_files[0], embeddings)
        file_to_remove = os.path.join(path, 'combined_content.txt')
        if os.path.exists(file_to_remove):
            os.remove(file_to_remove)
        return local_db

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, path, result_queue):
        """Load an existing index and put it on ``result_queue``.

        Args:
            index_files: Candidate index folders; the first one is loaded.
            embeddings: Embeddings object used to load the store.
            logger: Logger used to report a missing store.
            path: Directory from which ``combined_content.txt`` is removed
                once the index has been loaded.
            result_queue: Queue that receives the loaded FAISS store.

        Raises:
            FileNotFoundError: If ``index_files`` is empty. Nothing is put on
                the queue in that case.
        """
        local_db = cls._load_existing_index(index_files, embeddings, logger, path)
        result_queue.put(local_db)  # Put the result in the queue

    @classmethod
    def embed_index(cls, url, path, target_col=None, sheet_name=None):
        """Build a new index from a file, or load an existing one from a directory.

        The work runs in the calling thread so that any loader or FAISS error
        propagates to the caller instead of leaving it blocked on a queue.

        Args:
            url: ``'NO_URL'`` to load an existing index; any other value to
                build one from the file at ``path``.
            path: File to index (build mode) or directory containing a
                ``*_index`` folder (load mode).
            target_col: Content column, used only for Excel input.
            sheet_name: Sheet to read, used only for Excel input.

        Returns:
            ``(local_db, index_store)`` in build mode, ``local_db`` in load
            mode, or ``None`` when ``path`` is empty.

        Raises:
            ValueError: In build mode, if the extension is unsupported or the
                file produced no chunks.
            FileNotFoundError: In load mode, if no ``*_index`` folder exists.
        """
        embeddings = OpenAIEmbeddings()

        if not path:
            return None

        if url != 'NO_URL':
            file_extension = os.path.splitext(path)[1].lower()
            doc_list = cls(path)._load_splits(file_extension, target_col, sheet_name)
            if not doc_list:
                raise ValueError("Unsupported file format")

            faiss_db = cls._build_faiss(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return local_db, index_store

        index_files = glob.glob(os.path.join(path, '*_index'))
        return cls._load_existing_index(index_files, embeddings, logger, path)

    @classmethod
    def query(cls, question: str, llm, index):
        """Query the vectorstore.

        Args:
            question: Natural-language question to answer.
            llm: Chat model to use; a falsy value selects a default
                ``gpt-3.5-turbo`` model with temperature 0.
            index: Vector store exposing ``as_retriever()``.

        Returns:
            The result of running the retrieval QA chain on ``question``.
        """
        llm = llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=index.as_retriever()
        )
        return chain.run(question)


if __name__ == '__main__':
    pass
    # Example for a search query. With url='NO_URL', `path` is the directory
    # that contains an existing "*_index" store and only the store is returned:
    # index = SearchableIndex.embed_index(
    #     url='NO_URL',
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents")
    # prompt = 'show more detail about types of data collected'
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # result = SearchableIndex.query(prompt, llm=llm, index=index)
    # print(result)