# Zeta / app.py: a Streamlit RAG chatbot over user-supplied PDF and web documents.
import os
from pathlib import Path
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain_experimental.text_splitter import SemanticChunker
import streamlit as st
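
# Pipeline: load PDFs / web pages -> chunk semantically -> embed into a local,
# persisted Chroma index -> answer questions with a conversational retrieval
# chain that compresses retrieved chunks before they reach the LLM.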
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
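
# Route each source URL to the appropriate loader: PyPDFLoader for PDFs,
# WebBaseLoader for anything else.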
def load_documents():
loaders = [
PyPDFLoader(source_doc_url)
if source_doc_url.endswith(".pdf")
else WebBaseLoader(source_doc_url)
for source_doc_url in st.session_state.source_doc_urls
]
documents = []
for loader in loaders:
documents.extend(loader.load())
return documents
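
# SemanticChunker splits on embedding-similarity breakpoints rather than a fixed
# character budget, so chunk boundaries track topic shifts in the source text.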
def split_documents(documents):
    # Embeddings are deterministic; OpenAIEmbeddings accepts no temperature argument.
    text_splitter = SemanticChunker(OpenAIEmbeddings())
texts = text_splitter.split_documents(documents)
return texts
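
# The compression retriever first pulls k=3 candidates via MMR search, then has
# an OpenAI LLM extract only the passages relevant to the query, shrinking the
# context handed to the answering chain.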
def embeddings_on_local_vectordb(texts):
vectordb = Chroma.from_documents(
texts,
        embedding=OpenAIEmbeddings(),
persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
)
vectordb.persist()
retriever = ContextualCompressionRetriever(
base_compressor=LLMChainExtractor.from_llm(OpenAI(temperature=0)),
base_retriever=vectordb.as_retriever(search_kwargs={"k": 3}, search_type="mmr"),
)
return retriever
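
# chain_type="refine" feeds retrieved chunks to the LLM one at a time,
# iteratively improving the answer; slower than "stuff" but it tolerates
# more retrieved context than fits in a single prompt.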
def query_llm(retriever, query):
qa_chain = ConversationalRetrievalChain.from_llm(
        # ChatOpenAI replaces the deprecated OpenAIChat wrapper.
        llm=ChatOpenAI(temperature=0),
retriever=retriever,
return_source_documents=True,
chain_type="refine",
)
    # return_source_documents=True means the chain already exposes the retrieved
    # (and LLM-compressed) chunks, so a second retrieval pass is unnecessary.
    result = qa_chain({"question": query, "chat_history": st.session_state.messages})
    answer = result["answer"]
    relevant_docs = result["source_documents"]
    st.session_state.messages.append((query, answer))
    return relevant_docs, answer
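
# Configuration arrives via the sidebar; the OpenAI key is read from the
# environment when available so it never has to live in the repository.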
def input_fields():
    # Never hard-code API keys in source; read from the environment, falling
    # back to a password field in the sidebar.
    os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY") or st.sidebar.text_input(
        "OpenAI API Key", type="password"
    )
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs (comma-separated)").split(",")
        if url.strip()
    ]
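
# Triggered by the "Submit Documents" button: builds the retriever once and
# caches it in session state for subsequent chat turns.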
def process_documents():
try:
documents = load_documents()
texts = split_documents(documents)
st.session_state.retriever = embeddings_on_local_vectordb(texts)
except Exception as e:
st.error(f"An error occurred: {e}")
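
# Streamlit reruns this script top-to-bottom on every interaction, so the chat
# history must be replayed from st.session_state on each pass.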
def boot():
st.title("Enigma Chatbot")
input_fields()
st.sidebar.button("Submit Documents", on_click=process_documents)
st.sidebar.write("---")
st.sidebar.write("References made during the chat will appear here")
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
st.chat_message("human").write(message[0])
st.chat_message("ai").write(message[1])
    if query := st.chat_input():
        # Guard against chatting before any documents have been indexed.
        if "retriever" not in st.session_state:
            st.error("Please submit documents before chatting.")
            st.stop()
        st.chat_message("human").write(query)
        references, response = query_llm(st.session_state.retriever, query)
        for doc in references:
            # WebBaseLoader documents carry no "page" metadata, so guard the lookup.
            st.sidebar.info(f"Page {doc.metadata.get('page', 'n/a')}\n\n{doc.page_content}")
st.chat_message("ai").write(response)
if __name__ == "__main__":
boot()
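
# Run locally (assumes streamlit, langchain, langchain-experimental, chromadb,
# openai, pypdf, and beautifulsoup4 are installed):
#   streamlit run app.py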