import gradio as gr
import os
import re
from pathlib import Path
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
import chromadb
from unidecode import unidecode
# List of allowed models
allowed_llms = [
"mistralai/Mistral-7B-Instruct-v0.2",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.1",
"google/gemma-7b-it",
"google/gemma-2b-it",
"HuggingFaceH4/zephyr-7b-beta",
"HuggingFaceH4/zephyr-7b-gemma-v0.1",
"meta-llama/Llama-2-7b-chat-hf"
]
list_llm_simple = [os.path.basename(llm) for llm in allowed_llms]
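# UI display names: e.g. "mistralai/Mistral-7B-Instruct-v0.2" is shown as
# "Mistral-7B-Instruct-v0.2".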
# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
loaders = [PyPDFLoader(x) for x in list_file_path]
pages = []
for loader in loaders:
pages.extend(loader.load())
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
doc_splits = text_splitter.split_documents(pages)
return doc_splits
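# Example call (hypothetical path):
#   splits = load_doc(["./docs/report.pdf"], chunk_size=600, chunk_overlap=40)
# Each split is a LangChain Document that inherits per-page metadata from
# PyPDFLoader, so chunks can later be traced back to their source page.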
# Create vector database
def create_db(splits, collection_name):
embedding = HuggingFaceEmbeddings()
new_client = chromadb.EphemeralClient()
vectordb = Chroma.from_documents(
documents=splits,
embedding=embedding,
client=new_client,
collection_name=collection_name,
)
return vectordb
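# Note: HuggingFaceEmbeddings defaults to sentence-transformers/all-mpnet-base-v2
# when no model_name is given, and chromadb.EphemeralClient() keeps the index
# in memory only, so the database is rebuilt for every session.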
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # load_in_8bit is not a named HuggingFaceEndpoint field, so it is passed
    # through model_kwargs; whether it takes effect depends on the backend.
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        model_kwargs={"load_in_8bit": True},
    )
memory = ConversationBufferMemory(
memory_key="chat_history",
output_key='answer',
return_messages=True
)
retriever = vector_db.as_retriever()
qa_chain = ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
chain_type="stuff",
memory=memory,
return_source_documents=True,
verbose=False,
)
return qa_chain
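# With chain_type="stuff", all retrieved chunks are concatenated into a single
# prompt; return_source_documents=True makes the chain's response dict expose
# "source_documents" alongside "answer", which conversation() below relies on.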
# Generate collection name for vector database
def create_collection_name(filepath):
collection_name = Path(filepath).stem
collection_name = unidecode(collection_name).replace(" ", "-")
collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
if len(collection_name) < 3:
collection_name = collection_name + 'xyz'
if not collection_name[0].isalnum():
collection_name = 'A' + collection_name[1:]
if not collection_name[-1].isalnum():
collection_name = collection_name[:-1] + 'Z'
return collection_name
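# Example: "Attention Is All You Need.pdf" -> "Attention-Is-All-You-Need".
# The padding and trimming above target Chroma's collection-name rules:
# 3-63 characters, starting and ending with an alphanumeric character.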
# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
list_file_path = [x.name for x in list_file_obj if x is not None]
collection_name = create_collection_name(list_file_path[0])
doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
vector_db = create_db(doc_splits, collection_name)
return vector_db, collection_name, "Complete!"
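# The collection is named after the first uploaded file only; all uploaded
# PDFs are still indexed together in that one collection.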
# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
llm_name = allowed_llms[llm_option]
qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
return qa_chain, "Complete!"
# Format chat history
def format_chat_history(message, chat_history):
formatted_chat_history = []
for user_message, bot_message in chat_history:
formatted_chat_history.append(f"User: {user_message}")
formatted_chat_history.append(f"Assistant: {bot_message}")
return formatted_chat_history
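# Example: format_chat_history("ignored", [("Hi", "Hello!")])
# -> ["User: Hi", "Assistant: Hello!"]; `message` is accepted but currently
# unused, keeping the signature aligned with how conversation() calls it.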
# Conversation handling
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    # Strip any echoed "Helpful Answer:" prefix from the generation
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    new_history = history + [(message, response_answer)]
    # Always emit exactly three (text, page) pairs so the number of return
    # values matches the output components wired in demo(), even when fewer
    # than three source documents are returned
    response_details = [(src.page_content.strip(), src.metadata.get("page", 0) + 1) for src in response_sources[:3]]
    while len(response_details) < 3:
        response_details.append(("", 0))
    return gr.update(value=""), new_history, *sum(response_details, ())
# Gradio Interface
def demo():
with gr.Blocks(theme="default") as demo:
vector_db = gr.State()
qa_chain = gr.State()
collection_name = gr.State()
gr.Markdown(
"""<center><h2>PDF-based Chatbot</h2></center>
<h3>Ask any questions about your PDF documents</h3>""")
with gr.Tab("Upload PDF"):
document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF Documents")
with gr.Tab("Process Document"):
            db_choice = gr.Radio(["ChromaDB"], label="Vector Database", value="ChromaDB", type="index")
with gr.Accordion("Advanced Options", open=False):
                slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk Size", interactive=True)
                slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk Overlap", interactive=True)
db_progress = gr.Textbox(label="Database Initialization Status", value="None")
db_btn = gr.Button("Generate Database")
with gr.Tab("Initialize QA Chain"):
llm_btn = gr.Radio(list_llm_simple, label="LLM Models", value=list_llm_simple[0], type="index")
with gr.Accordion("Advanced Options", open=False):
                slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", interactive=True)
                slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", interactive=True)
                slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k Samples", interactive=True)
llm_progress = gr.Textbox(value="None", label="QA Chain Initialization Status")
qachain_btn = gr.Button("Initialize QA Chain")
with gr.Tab("Chatbot"):
chatbot = gr.Chatbot(height=300)
with gr.Accordion("Document References", open=False):
for i in range(1, 4):
gr.Row([gr.Textbox(label=f"Reference {i}", lines=2, container=True, scale=20), gr.Number(label="Page", scale=1)])
msg = gr.Textbox(placeholder="Type message here...", container=True)
gr.Row([gr.Button("Submit"), gr.Button("Clear Conversation")])
# Define Interactions
db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
        # Three (text, page) reference slots follow the message box and chatbot
        chat_outputs = [msg, chatbot] + [c for pair in zip(ref_texts, ref_pages) for c in pair]
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=chat_outputs)
        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], inputs=None, outputs=[chatbot] + chat_outputs[2:])
demo.launch(debug=True)
if __name__ == "__main__":
demo()