# Swap the stdlib sqlite3 module for the bundled pysqlite3 build before
# chromadb imports it: Chroma requires sqlite3 >= 3.35, which many system
# Pythons do not ship.
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

from tempfile import NamedTemporaryFile
from typing import List

import chainlit as cl
from chainlit.types import AskFileResponse
import chromadb
from chromadb.config import Settings
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PDFPlumberLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.vectorstores.base import VectorStore

from prompt import EXAMPLE_PROMPT, PROMPT


def process_file(*, file: AskFileResponse) -> List[Document]:
    """Processes one PDF file from a Chainlit AskFileResponse object by
    loading the PDF document and then chunking it into sub-documents. Only
    supports PDF files.

    Args:
        file (AskFileResponse): input file to be processed

    Raises:
        ValueError: when we fail to process the PDF file. Processing is
            considered a failure when no text is returned, e.g. for PDFs
            with image-only contents, corrupted PDFs, etc.

    Returns:
        List[Document]: List of Documents. Each individual document has two
            fields: page_content (string) and metadata (dict).
    """
    if file.type != "application/pdf":
        raise TypeError("Only PDF files are supported")

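    # PDFPlumberLoader reads from a filesystem path, so stage the uploaded
    # bytes in a temporary file first.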
    with NamedTemporaryFile() as tempfile:
        tempfile.write(file.content)
        # Flush so the loader sees the full file contents on disk.
        tempfile.flush()

        loader = PDFPlumberLoader(tempfile.name)
        documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=3000, chunk_overlap=100
        )
        docs = text_splitter.split_documents(documents)

        # Tag each chunk with a unique id so answers can cite their sources.
        for i, doc in enumerate(docs):
            doc.metadata["source"] = f"source_{i}"

        if not docs:
            raise ValueError("PDF file parsing failed.")

        return docs


def create_search_engine(
    *, docs: List[Document], embeddings: Embeddings
) -> VectorStore:
    """Takes a list of LangChain Documents and an embedding model API wrapper
    and builds a search index using a VectorStore.

    Args:
        docs (List[Document]): List of LangChain Documents to be indexed into
            the search engine.
        embeddings (Embeddings): encoder model API used to calculate
            embeddings.

    Returns:
        VectorStore: LangChain VectorStore
    """
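    # An EphemeralClient keeps the whole index in memory for this session;
    # allow_reset=True is what permits the reset() call below.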
    client = chromadb.EphemeralClient()
    client_settings = Settings(allow_reset=True, anonymized_telemetry=False)

    # Reset the client first so re-uploads start from an empty index rather
    # than appending duplicate documents.
    search_engine = Chroma(client=client, client_settings=client_settings)
    search_engine._client.reset()

    search_engine = Chroma.from_documents(
        client=client,
        documents=docs,
        embedding=embeddings,
        client_settings=client_settings,
    )

    return search_engine


@cl.on_chat_start
async def on_chat_start():
    """Prepares the environment for the chat-with-PDF application: asks for
    a PDF upload, indexes it, and stores the QA chain in the user session.
    Decorated with cl.on_chat_start.

    Returns:
        None
    """
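    # Keep asking until the user uploads a file; send() returns None if the
    # prompt times out without an upload.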
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload the PDF file you want to chat with...",
            accept=["application/pdf"],
            max_size_mb=20,
        ).send()
    file = files[0]

    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    docs = process_file(file=file)
    cl.user_session.set("docs", docs)
    msg.content = f"`{file.name}` processed. Loading..."
    await msg.update()

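    # OpenAIEmbeddings reads the OPENAI_API_KEY environment variable; the
    # indexing work is pushed off the event loop via cl.make_async so the UI
    # stays responsive.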
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    try:
        search_engine = await cl.make_async(create_search_engine)(
            docs=docs, embeddings=embeddings
        )
    except Exception as e:
        await cl.Message(content=f"Error: {e}").send()
        raise SystemError from e
    msg.content = f"`{file.name}` loaded. You can now ask questions!"
    await msg.update()

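    # Build the QA chain. "stuff" packs all retrieved chunks into a single
    # prompt, so retrieved tokens are capped to keep the combined prompt
    # within the model's context window.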
    model = ChatOpenAI(
        model="gpt-3.5-turbo-16k-0613", temperature=0, streaming=True
    )

    chain = RetrievalQAWithSourcesChain.from_chain_type(
        llm=model,
        chain_type="stuff",
        retriever=search_engine.as_retriever(max_tokens_limit=4097),
        chain_type_kwargs={"prompt": PROMPT, "document_prompt": EXAMPLE_PROMPT},
    )

    cl.user_session.set("chain", chain)


@cl.on_message
async def main(message: cl.Message):
    """Answers a user question against the indexed PDF and attaches the
    source chunks cited in the response.
    """
    chain = cl.user_session.get("chain")

    response = await chain.acall(
        message.content,
        callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)],
    )
    answer = response["answer"]
    sources = response["sources"].strip()

    # Map the chain's cited source ids back to the original chunks.
    docs = cl.user_session.get("docs")
    metadatas = [doc.metadata for doc in docs]
    all_sources = [m["source"] for m in metadatas]

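    # Attach each cited chunk as a Chainlit Text element so the user can
    # inspect the passages behind the answer.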
    source_elements = []
    if sources:
        found_sources = []

        for source in sources.split(","):
            # Normalize, e.g. "source_3." -> "source_3", before lookup.
            source_name = source.strip().replace(".", "")

            try:
                index = all_sources.index(source_name)
            except ValueError:
                continue
            text = docs[index].page_content
            found_sources.append(source_name)

            source_elements.append(cl.Text(content=text, name=source_name))

        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += "\nNo sources found"

    await cl.Message(content=answer, elements=source_elements).send()