Spaces:
Build error
Build error
File size: 5,135 Bytes
d97a6fa e44f2dc d97a6fa 085b39c d97a6fa 085b39c d97a6fa e44f2dc d97a6fa e44f2dc 085b39c d97a6fa e44f2dc d97a6fa e44f2dc d97a6fa 085b39c d97a6fa e44f2dc d97a6fa e44f2dc 085b39c d97a6fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import (
PyPDFLoader,
DataFrameLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain.chat_models import ChatOpenAI
from bot.utils.show_log import logger
import pandas as pd
import threading
import glob
import os
import asyncio
import queue
class Query:
    """Couple a natural-language question with an LLM and a FAISS index.

    Calling :meth:`query` runs retrieval-augmented QA over the index and
    returns the model's answer as a string.
    """

    def __init__(self, question, llm, index):
        # question: the user's prompt; llm: optional chat model (a default
        # is created lazily); index: FAISS vectorstore to retrieve from.
        self.question = question
        self.llm = llm
        self.index = index

    def query(self):
        """Query the vectorstore and return the chain's answer."""
        # Fall back to a default chat model when no (truthy) LLM was given.
        if self.llm:
            model = self.llm
        else:
            model = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
        qa_chain = RetrievalQA.from_chain_type(
            model, retriever=self.index.as_retriever()
        )
        return qa_chain.run(self.question)
class SearchableIndex:
    """Build, persist, merge and query FAISS vector indexes built from
    local documents (.txt, .pdf, Excel workbooks, .csv).

    Indexes are saved next to the source file as ``<basename>_index`` and
    merged into an existing store when one is already present.
    """

    # Chunking parameters shared by the .txt and .pdf splitters.
    _CHUNK_SIZE = 1000
    _CHUNK_OVERLAP = 0

    def __init__(self, path):
        # Path of the source document or the directory holding *_index stores.
        self.path = path

    @classmethod
    def _make_splitter(cls):
        """Return the character splitter used for plain-text content."""
        return RecursiveCharacterTextSplitter(
            chunk_size=cls._CHUNK_SIZE,
            chunk_overlap=cls._CHUNK_OVERLAP,
            length_function=len,
        )

    @classmethod
    def get_splits(cls, path, target_col=None, sheet_name=None):
        """Load *path* and split its content into embeddable chunks.

        Returns a list of strings for .txt/.pdf sources, or a list of
        langchain ``Document`` objects for Excel/.csv sources.

        Raises:
            ValueError: if the file extension is not supported.
        """
        extension = os.path.splitext(path)[1].lower()
        if extension == ".txt":
            with open(path, 'r') as txt:
                data = txt.read()
            return cls._make_splitter().split_text(data)
        if extension == ".pdf":
            pages = PyPDFLoader(path).load_and_split()
            splitter = cls._make_splitter()
            doc_list = []
            for pg in pages:
                doc_list.extend(splitter.split_text(pg.page_content))
            return doc_list
        # BUG FIX: the original matched only ".xml" for Excel workbooks even
        # though it reads them with pd.read_excel/openpyxl, so real .xlsx/.xls
        # files were rejected.  Accept the real Excel extensions and keep
        # ".xml" for backward compatibility with existing callers.
        if extension in (".xlsx", ".xls", ".xml"):
            df = pd.read_excel(io=path, engine='openpyxl', sheet_name=sheet_name)
            return DataFrameLoader(df, page_content_column=target_col).load()
        if extension == ".csv":
            return CSVLoader(path).load()
        raise ValueError(f"Unsupported file format: {extension!r}")

    @classmethod
    def merge_or_create_index(cls, index_store, faiss_db, embeddings, logger):
        """Merge *faiss_db* into the store at *index_store*, or create it.

        Returns the saved FAISS database.
        """
        if os.path.exists(index_store):
            local_db = FAISS.load_local(index_store, embeddings)
            local_db.merge_from(faiss_db)
            operation_info = "Merge"
        else:
            # No existing store: persist the freshly built index directly.
            local_db = faiss_db
            operation_info = "New store creation"
        local_db.save_local(index_store)
        logger.info(f"{operation_info} index completed")
        return local_db

    @classmethod
    def load_index(cls, index_files, embeddings, logger):
        """Load the first index store in *index_files*, or None if empty."""
        if index_files:
            return FAISS.load_local(index_files[0], embeddings)
        logger.warning("Index store does not exist")
        return None

    @classmethod
    def check_and_load_index(cls, index_files, embeddings, logger, result_queue):
        """Thread target: load the index and push the result onto the queue."""
        result_queue.put(cls.load_index(index_files, embeddings, logger))

    @classmethod
    def load_index_asynchronously(cls, index_files, embeddings, logger):
        """Load an index on a worker thread and wait for the result.

        NOTE(review): despite the name this method blocks the caller — the
        thread is joined immediately, so loading is effectively synchronous.
        The threaded shape is kept for interface compatibility.
        """
        result_queue = queue.Queue()
        thread = threading.Thread(
            target=cls.check_and_load_index,
            args=(index_files, embeddings, logger, result_queue)
        )
        thread.start()
        thread.join()
        return result_queue.get()

    @classmethod
    def embed_index(cls, url, path, llm, prompt, target_col=None, sheet_name=None):
        """Embed *path* into a FAISS store (or load an existing store) and
        return a :class:`Query` ready to answer *prompt*.

        When ``url != 'NO_URL'`` *path* is a document to embed; otherwise
        *path* is treated as a directory containing ``*_index`` stores.
        Returns None when *path* is falsy (preserves original behavior).
        """
        embeddings = OpenAIEmbeddings()
        if not path:
            return None
        if url != 'NO_URL':
            doc_list = cls.get_splits(path, target_col, sheet_name)
            # BUG FIX: the CSV/Excel branches of get_splits return Document
            # objects, which FAISS.from_texts cannot embed; dispatch on the
            # element type so every supported format actually works.
            if doc_list and not isinstance(doc_list[0], str):
                faiss_db = FAISS.from_documents(doc_list, embeddings)
            else:
                faiss_db = FAISS.from_texts(doc_list, embeddings)
            index_store = os.path.splitext(path)[0] + "_index"
            local_db = cls.merge_or_create_index(index_store, faiss_db, embeddings, logger)
            return Query(prompt, llm, local_db)
        index_files = glob.glob(os.path.join(path, '*_index'))
        local_db = cls.load_index_asynchronously(index_files, embeddings, logger)
        return Query(prompt, llm, local_db)
if __name__ == '__main__':
    pass
    # Example usage (requires OPENAI_API_KEY in the environment):
    # llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
    # q = SearchableIndex.embed_index(
    #     url='HAS_URL',  # any value other than 'NO_URL' embeds the file at `path`
    #     path='/Users/macbook/Downloads/AI_test_exam/ChatBot/learning_documents/combined_content.txt',
    #     llm=llm,
    #     prompt='show more detail about types of data collected')
    # print(q.query())
|