File size: 5,135 Bytes
d97a6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e44f2dc
d97a6fa
 
 
085b39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97a6fa
 
 
 
085b39c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97a6fa
 
 
 
 
 
 
e44f2dc
d97a6fa
e44f2dc
 
 
 
 
085b39c
d97a6fa
 
e44f2dc
d97a6fa
e44f2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97a6fa
 
085b39c
d97a6fa
 
e44f2dc
 
 
 
 
 
 
d97a6fa
e44f2dc
 
085b39c
d97a6fa
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
    PyPDFLoader,
    DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import asyncio
import queue


class Query:
    """Bundle a question, an optional chat model, and a FAISS index for QA."""

    def __init__(self, question, llm, index):
        # Natural-language question to run against the index.
        self.question = question
        # Chat model instance; a default ChatOpenAI is created on demand.
        self.llm = llm
        # Vectorstore (FAISS) used as the retriever backend.
        self.index = index

    def query(self):
        """Run a RetrievalQA chain over the index and return the answer text."""
        model = self.llm
        if not model:
            # Fall back to a deterministic default chat model.
            model = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        qa_chain = RetrievalQA.from_chain_type(
            model, retriever=self.index.as_retriever()
        )
        return qa_chain.run(self.question)


class SearchableIndex:
    """Build, persist, load, and query FAISS vector indexes from local files."""

    # Extensions routed through pandas.read_excel.  ".xml" is kept for
    # backward compatibility: the original code sent ".xml" to read_excel
    # (almost certainly a typo for ".xlsx"), and existing callers may rely
    # on it.
    _EXCEL_EXTENSIONS = (".xml", ".xlsx", ".xls")

    def __init__(self, path):
        # Root path of the document / index store.
        self.path = path

    @classmethod
    def get_splits(cls, path, target_col=None, sheet_name=None):
        """Load *path* and split its content into embeddable chunks.

        Returns a list of plain-text chunks for .txt/.pdf inputs, or a list
        of langchain Document objects for excel/.csv inputs.

        Raises:
            ValueError: if the file extension is not supported.
        """
        extension = os.path.splitext(path)[1].lower()
        # Single splitter instance shared by the text-based branches.
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000,
                                                  chunk_overlap=0,
                                                  length_function=len)
        if extension == ".txt":
            # Explicit encoding avoids platform-dependent default decoding.
            with open(path, 'r', encoding='utf-8') as txt:
                return splitter.split_text(txt.read())
        if extension == ".pdf":
            doc_list = []
            for pg in PyPDFLoader(path).load_and_split():
                doc_list.extend(splitter.split_text(pg.page_content))
            return doc_list
        if extension in cls._EXCEL_EXTENSIONS:
            df = pd.read_excel(io=path, engine='openpyxl', sheet_name=sheet_name)
            return DataFrameLoader(df, page_content_column=target_col).load()
        if extension == ".csv":
            return CSVLoader(path).load()
        raise ValueError(f"Unsupported file format: {extension}")

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        """Merge *faiss_db* into an existing store at *index_store*, or create it.

        Returns the resulting (saved) FAISS vectorstore.
        """
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            operation_info = "Merge"
        else:
            # No existing store: persist the freshly built index as-is.
            local_db = faiss_db
            operation_info = "New store creation"

        local_db.save_local(index_store)
        logger.info(f"{operation_info} index completed")
        return local_db

    @classmethod
    def load_index(cls, index_files, embeddings, logger):
        """Load the first index store in *index_files*; None if there is none."""
        if index_files:
            return FAISS.load_local(index_files[0], embeddings)
        logger.warning("Index store does not exist")
        return None

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, result_queue):
        """Load an index and push the result onto *result_queue*.

        Kept for backward compatibility with callers that drive loading from
        their own worker thread.
        """
        local_db = cls.load_index(index_files, embeddings, logger)
        result_queue.put(local_db)

    @classmethod
    def load_index_asynchronously(cls, index_files, embeddings, logger):
        """Load an index store (synchronously) and return it.

        The previous implementation spawned a thread and immediately joined
        it, which gave no concurrency benefit and deadlocked on the result
        queue if the worker raised.  A direct call has identical results and
        propagates errors to the caller.  The method name is preserved so
        existing call sites keep working.
        """
        return cls.load_index(index_files, embeddings, logger)

    @classmethod
    def embed_index(cls, url, path, llm, prompt, target_col=None, sheet_name=None):
        """Embed a document into a FAISS store, or load an existing store.

        When *url* != 'NO_URL', *path* is a document to split, embed, and
        merge into a sibling "<path>_index" store.  Otherwise *path* is a
        directory searched for existing '*_index' stores.

        Returns a Query bound to the resulting vectorstore, or None when
        *path* is falsy.
        """
        if not path:
            return None
        embeddings = OpenAIEmbeddings()

        if url != 'NO_URL':
            doc_list = cls.get_splits(path, target_col, sheet_name)
            # txt/pdf splits are plain strings; excel/csv loaders return
            # Document objects — FAISS has a distinct constructor for each.
            if doc_list and not isinstance(doc_list[0], str):
                faiss_db = FAISS.from_documents(doc_list, embeddings)
            else:
                faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return Query(prompt, llm, local_db)

        index_files = glob.glob(os.path.join(path, '*_index'))
        local_db = cls.load_index_asynchronously(index_files, embeddings, logger)
        return Query(prompt, llm, local_db)


if __name__ == '__main__':
    pass
    # Example usage (requires OPENAI_API_KEY and langchain installed):
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # q = SearchableIndex.embed_index(
    #     url='SOME_URL',  # any value other than 'NO_URL' triggers (re)embedding
    #     path='learning_documents/combined_content.txt',
    #     llm=llm,
    #     prompt='show more detail about types of data collected')
    # print(q.query())  # embed_index returns a Query; .query() runs the chain