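"""Chainlit demo app: chat with a PDF.

The user uploads a PDF, which is chunked and indexed into a Chroma vector
store; questions are then answered with a RetrievalQAWithSourcesChain that
cites the source chunks it used.
"""
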
# Chroma compatibility issue resolution
# https://docs.trychroma.com/troubleshooting#sqlite
__import__("pysqlite3")
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
from tempfile import NamedTemporaryFile
from typing import List
import chainlit as cl
from chainlit.types import AskFileResponse
import chromadb
from chromadb.config import Settings
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PDFPlumberLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores.base import VectorStore
from prompt import EXAMPLE_PROMPT, PROMPT
def process_file(*, file: AskFileResponse) -> List[Document]:
"""Processes one PDF file from a Chainlit AskFileResponse object by first
loading the PDF document and then chunk it into sub documents. Only
supports PDF files.
Args:
file (AskFileResponse): input file to be processed
Raises:
ValueError: when we fail to process PDF files. We consider PDF file
processing failure when there's no text returned. For example, PDFs
with only image contents, corrupted PDFs, etc.
Returns:
List[Document]: List of Document(s). Each individual document has two
fields: page_content(string) and metadata(dict).
"""
if file.type != "application/pdf":
raise TypeError("Only PDF files are supported")
    with NamedTemporaryFile() as tempfile:
        tempfile.write(file.content)
        # Flush buffered bytes so PDFPlumberLoader can read the complete file
        # from disk by name.
        tempfile.flush()

        loader = PDFPlumberLoader(tempfile.name)
        documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=3000, chunk_overlap=100
)
docs = text_splitter.split_documents(documents)
    # Add a "source" id to each chunk's metadata so answers can cite it
    for i, doc in enumerate(docs):
        doc.metadata["source"] = f"source_{i}"
if not docs:
raise ValueError("PDF file parsing failed.")
return docs
def create_search_engine(
*, docs: List[Document], embeddings: Embeddings
) -> VectorStore:
"""Takes a list of Langchain Documents and an embedding model API wrapper
and build a search index using a VectorStore.
Args:
docs (List[Document]): List of Langchain Documents to be indexed into
the search engine.
embeddings (Embeddings): encoder model API used to calculate embedding
Returns:
VectorStore: Langchain VectorStore
"""
    # Initialize the Chromadb client with resetting enabled and telemetry
    # disabled, so the reset below is allowed.
    client_settings = Settings(allow_reset=True, anonymized_telemetry=False)
    client = chromadb.EphemeralClient(settings=client_settings)
# Reset the search engine to ensure we don't use old copies.
# NOTE: we do not need this for production
search_engine = Chroma(client=client, client_settings=client_settings)
search_engine._client.reset()
search_engine = Chroma.from_documents(
client=client,
documents=docs,
embedding=embeddings,
client_settings=client_settings,
)
return search_engine
@cl.on_chat_start
async def on_chat_start():
"""This function is written to prepare the environments for the chat
with PDF application. It should be decorated with cl.on_chat_start.
Returns:
None
"""
    # Ask the user to upload a PDF to chat with
files = None
while files is None:
files = await cl.AskFileMessage(
content="Please Upload the PDF file you want to chat with...",
accept=["application/pdf"],
max_size_mb=20,
).send()
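    # AskFileMessage blocks until the user uploads a file; it returns a list,
    # and we only use the first entry.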
file = files[0]
# Process and save data in the user session
msg = cl.Message(content=f"Processing `{file.name}`...")
await msg.send()
docs = process_file(file=file)
cl.user_session.set("docs", docs)
msg.content = f"`{file.name}` processed. Loading ..."
await msg.update()
# Indexing documents into our search engine
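    # NOTE: OpenAIEmbeddings (and ChatOpenAI below) read the OPENAI_API_KEY
    # environment variable, which must be set for this app to work.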
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
try:
search_engine = await cl.make_async(create_search_engine)(
docs=docs, embeddings=embeddings
)
except Exception as e:
await cl.Message(content=f"Error: {e}").send()
raise SystemError
msg.content = f"`{file.name}` loaded. You can now ask questions!"
await msg.update()
model = ChatOpenAI(
model="gpt-3.5-turbo-16k-0613", temperature=0, streaming=True
)
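    # streaming=True lets the AsyncLangchainCallbackHandler in main() stream
    # the final answer back to the UI token by token.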
chain = RetrievalQAWithSourcesChain.from_chain_type(
llm=model,
chain_type="stuff",
retriever=search_engine.as_retriever(max_tokens_limit=4097),
chain_type_kwargs={"prompt": PROMPT, "document_prompt": EXAMPLE_PROMPT},
)
# We are saving the chain in user_session, so we do not have to rebuild
# it every single time.
cl.user_session.set("chain", chain)
@cl.on_message
async def main(message: cl.Message):
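    """Handles an incoming user message: runs the retrieval QA chain and
    sends back the answer together with the source chunks it cites.
    """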
# Let's load the chain from user_session
chain = cl.user_session.get("chain") # type: RetrievalQAWithSourcesChain
response = await chain.acall(
message.content,
callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)],
)
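    # The chain returns an "answer" plus a comma-separated "sources" string
    # that references the source_{i} ids assigned in process_file().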
answer = response["answer"]
sources = response["sources"].strip()
# Get all of the documents from user session
docs = cl.user_session.get("docs")
metadatas = [doc.metadata for doc in docs]
all_sources = [m["source"] for m in metadatas]
# Adding sources to the answer
source_elements = []
if sources:
found_sources = []
# Add the sources to the message
for source in sources.split(","):
source_name = source.strip().replace(".", "")
# Get the index of the source
try:
index = all_sources.index(source_name)
except ValueError:
continue
text = docs[index].page_content
found_sources.append(source_name)
# Create the text element referenced in the message
source_elements.append(cl.Text(content=text, name=source_name))
if found_sources:
answer += f"\nSources: {', '.join(found_sources)}"
else:
answer += "\nNo sources found"
await cl.Message(content=answer, elements=source_elements).send()