File size: 5,669 Bytes
d97a6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader,
    DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import queue


class SearchableIndex:
    """Build, persist, and query FAISS vector indexes from local documents.

    Supported sources: plain text (.txt), PDF (.pdf), Excel (.xml extension
    in this project — loaded via openpyxl), and CSV (.csv). Index stores are
    saved on disk as ``<source-name>_index`` directories next to the source.
    """

    def __init__(self, path):
        # Path to the source document (txt / pdf / Excel / csv).
        self.path = path

    def get_text_splits(self):
        """Read a plain-text file and split it into ~1000-char chunks.

        Returns:
            list[str]: the chunk texts.
        """
        # Explicit encoding: the default is platform-dependent and breaks
        # on non-ASCII content under some locales.
        with open(self.path, 'r', encoding='utf-8') as txt:
            data = txt.read()

        splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                  chunk_overlap=0,
                                                  length_function=len)
        return splitter.split_text(data)

    def get_pdf_splits(self):
        """Load a PDF page by page and split each page into ~1000-char chunks.

        Returns:
            list[str]: chunk texts across all pages, in page order.
        """
        pages = PyPDFLoader(self.path).load_and_split()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                  chunk_overlap=0,
                                                  length_function=len)
        doc_list = []
        for page in pages:
            doc_list.extend(splitter.split_text(page.page_content))
        return doc_list

    def get_xml_splits(self, target_col, sheet_name):
        """Load an Excel worksheet and return one document per row.

        Args:
            target_col: Column whose text becomes each document's content.
            sheet_name: Worksheet name (or index) to read.

        Returns:
            list: langchain Documents, one per DataFrame row.
        """
        df = pd.read_excel(io=self.path,
                           engine='openpyxl',
                           sheet_name=sheet_name)
        df_loader = DataFrameLoader(df, page_content_column=target_col)
        return df_loader.load()

    def get_csv_splits(self):
        """Load a CSV file and return one document per row."""
        return CSVLoader(self.path).load()

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        """Merge ``faiss_db`` into an existing store, or create a new one.

        Args:
            index_store: Directory path of the on-disk FAISS store.
            faiss_db: Freshly built in-memory FAISS database to persist.
            embeddings: Embedding function used to (re)load the store.
            logger: Logger for progress messages.

        Returns:
            The FAISS database loaded back from ``index_store``.
        """
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            logger.info("Merge index completed")
            local_db.save_local(index_store)
            return local_db
        # No store yet: persist the new db, then reload it from disk so the
        # returned object is backed by the saved store.
        faiss_db.save_local(folder_path=index_store)
        logger.info("New store created and loaded...")
        return FAISS.load_local(index_store, embeddings)

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, path, result_queue):
        """Load the first existing index store and put it on ``result_queue``.

        Runs in a worker thread. Always puts exactly one result on the queue
        (``None`` when no store exists) so the consumer blocked on
        ``result_queue.get()`` can never deadlock. The previous
        ``raise logger.warning(...)`` raised ``None`` (a TypeError) inside
        the thread and left the queue empty, hanging the caller forever.

        Args:
            index_files: Candidate ``*_index`` store directories; the first
                one is loaded.
            embeddings: Embedding function used to load the store.
            logger: Logger for the missing-store warning.
            path: Directory holding the temporary combined text file.
            result_queue: Queue that receives the loaded db (or ``None``).
        """
        local_db = None
        if index_files:
            local_db = FAISS.load_local(index_files[0], embeddings)
            # Clean up the temporary combined text file once indexed.
            file_to_remove = os.path.join(path, 'combined_content.txt')
            if os.path.exists(file_to_remove):
                os.remove(file_to_remove)
        else:
            logger.warning("Index store does not exist")
        result_queue.put(local_db)  # Put the result in the queue

    @classmethod
    def embed_index(cls, url, path, target_col=None, sheet_name=None):
        """Create (url given) or load (url == 'NO_URL') a FAISS index.

        Args:
            url: Any value other than ``'NO_URL'`` triggers index creation
                from the document at ``path``; ``'NO_URL'`` loads an
                existing ``*_index`` store found under the ``path`` directory.
            path: Source file path (create mode) or directory (load mode).
            target_col: Excel only — column holding the document text.
            sheet_name: Excel only — worksheet to read.

        Returns:
            ``(local_db, index_store)`` in create mode, ``local_db`` in load
            mode, or ``None`` when ``path`` is falsy.

        Raises:
            ValueError: If the file extension is unsupported or yields no text.
        """
        embeddings = OpenAIEmbeddings()

        def process_docs(out_queue, extension):
            # Worker: dispatch on extension; unsupported types produce None.
            instance = cls(path)
            dispatch = {
                '.txt': instance.get_text_splits,
                '.pdf': instance.get_pdf_splits,
                '.xml': lambda: instance.get_xml_splits(target_col, sheet_name),
                '.csv': instance.get_csv_splits,
            }
            loader = dispatch.get(extension)
            out_queue.put(loader() if loader is not None else None)

        if url != 'NO_URL' and path:
            file_extension = os.path.splitext(path)[1].lower()
            data_queue = queue.Queue()
            worker = threading.Thread(target=process_docs,
                                      args=(data_queue, file_extension))
            worker.start()
            doc_list = data_queue.get()
            worker.join()
            if not doc_list:
                raise ValueError("Unsupported file format")

            faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db,
                                                 embeddings, logger)
            return local_db, index_store
        elif url == 'NO_URL' and path:
            index_files = glob.glob(os.path.join(path, '*_index'))
            result_queue = queue.Queue()  # Receives the loaded db (or None)
            worker = threading.Thread(target=cls.check_and_load_index,
                                      args=(index_files, embeddings, logger,
                                            path, result_queue))
            worker.start()
            local_db = result_queue.get()  # Blocks until the worker posts
            worker.join()
            return local_db

    @classmethod
    def query(cls, question: str, llm, index):
        """Answer ``question`` with a RetrievalQA chain over ``index``.

        Args:
            question: Natural-language query.
            llm: Chat model to use; falls back to gpt-3.5-turbo when falsy.
            index: Vectorstore exposing ``as_retriever()``.

        Returns:
            str: The chain's answer.
        """
        llm = llm or ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        chain = RetrievalQA.from_chain_type(
            llm, retriever=index.as_retriever()
        )
        return chain.run(question)


if __name__ == '__main__':
    pass
    # Example usage (requires langchain installed and OPENAI_API_KEY set):
    # index = SearchableIndex.embed_index(
    #     path="/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt")
    # prompt = 'show more detail about types of data collected'
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # result = SearchableIndex.query(prompt, llm=llm, index=index)
    # print(result)