File size: 7,338 Bytes
3ec9224
5be8df6
85a72e5
 
3ec9224
5db4902
5be8df6
5db4902
5be8df6
85a72e5
ecf1633
1ef8d7c
aa98840
1ef8d7c
e28718b
 
 
 
 
 
 
 
 
 
37ae113
e28718b
5be8df6
 
 
 
 
 
 
 
e28718b
 
 
5be8df6
 
 
 
1ef8d7c
5be8df6
1ef8d7c
5be8df6
 
 
1ef8d7c
 
5be8df6
 
 
 
 
e28718b
 
 
 
 
 
 
5be8df6
 
 
9733941
5be8df6
 
e28718b
 
5be8df6
 
 
 
 
9733941
138ca2e
5be8df6
00bd139
5be8df6
9bf736d
 
 
e28718b
 
08108c1
 
fa7cc51
6e8daa8
fa7cc51
6e8daa8
9bf736d
 
 
 
 
 
5be8df6
1ef8d7c
 
5be8df6
e28718b
5be8df6
e28718b
00bd139
 
5be8df6
e28718b
5be8df6
 
 
 
 
 
 
e28718b
00bd139
5be8df6
 
e28718b
9733941
 
e28718b
85a72e5
5be8df6
e28718b
5be8df6
e28718b
3ca2785
00bd139
1ef8d7c
5be8df6
 
e28718b
6f396af
51d2a09
e28718b
 
51d2a09
e28718b
 
 
 
 
 
 
5be8df6
e28718b
 
 
 
 
 
 
 
 
 
14155e5
e28718b
 
 
 
 
5be8df6
e28718b
 
 
85a72e5
5be8df6
e28718b
5be8df6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import os
import re
from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
import chromadb
from unidecode import unidecode

# List of allowed models
allowed_llms = [
    "mistralai/Mistral-7B-Instruct-v0.2", 
    "mistralai/Mixtral-8x7B-Instruct-v0.1", 
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it", 
    "google/gemma-2b-it", 
    "HuggingFaceH4/zephyr-7b-beta", 
    "HuggingFaceH4/zephyr-7b-gemma-v0.1", 
    "meta-llama/Llama-2-7b-chat-hf"
]
list_llm_simple = [os.path.basename(llm) for llm in allowed_llms]

# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, 
        chunk_overlap=chunk_overlap
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model, 
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        load_in_8bit=True,
    )
    
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

# Generate collection name for vector database
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = unidecode(collection_name).replace(" ", "-")
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

# Initialize database
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, "Complete!"

# Initialize LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = allowed_llms[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

# Format chat history
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

# Conversation handling
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    new_history = history + [(message, response_answer)]
    response_details = [(src.page_content.strip(), src.metadata["page"] + 1) for src in response_sources[:3]]
    return gr.update(value=""), new_history, *sum(response_details, ())

# Gradio Interface
def demo():
    with gr.Blocks(theme="default") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        
        gr.Markdown(
        """<center><h2>PDF-based Chatbot</h2></center>
        <h3>Ask any questions about your PDF documents</h3>""")
        
        with gr.Tab("Upload PDF"):
            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF Documents")
        
        with gr.Tab("Process Document"):
            db_btn = gr.Radio(["ChromaDB"], label="Vector Database", value="ChromaDB", type="index")
            with gr.Accordion("Advanced Options", open=False):
                slider_chunk_size = gr.Slider(100, 1000, 600, 20, label="Chunk Size", interactive=True)
                slider_chunk_overlap = gr.Slider(10, 200, 40, 10, label="Chunk Overlap", interactive=True)
            db_progress = gr.Textbox(label="Database Initialization Status", value="None")
            db_btn = gr.Button("Generate Database")
            
        with gr.Tab("Initialize QA Chain"):
            llm_btn = gr.Radio(list_llm_simple, label="LLM Models", value=list_llm_simple[0], type="index")
            with gr.Accordion("Advanced Options", open=False):
                slider_temperature = gr.Slider(0.01, 1.0, 0.7, 0.1, label="Temperature", interactive=True)
                slider_maxtokens = gr.Slider(224, 4096, 1024, 32, label="Max Tokens", interactive=True)
                slider_topk = gr.Slider(1, 10, 3, 1, label="Top-k Samples", interactive=True)
            llm_progress = gr.Textbox(value="None", label="QA Chain Initialization Status")
            qachain_btn = gr.Button("Initialize QA Chain")

        with gr.Tab("Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Document References", open=False):
                for i in range(1, 4):
                    gr.Row([gr.Textbox(label=f"Reference {i}", lines=2, container=True, scale=20), gr.Number(label="Page", scale=1)])
            msg = gr.Textbox(placeholder="Type message here...", container=True)
            gr.Row([gr.Button("Submit"), gr.Button("Clear Conversation")])
            
        # Define Interactions
        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress])
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[msg, chatbot] + [None] * 6)

    demo.launch(debug=True)

if __name__ == "__main__":
    demo()