CCCDev committed
Commit 8ddfd2d · verified · 1 Parent(s): b44fef3

Create app.py

Files changed (1)
app.py +274 -0
app.py ADDED
@@ -0,0 +1,274 @@
import gradio as gr
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

from pathlib import Path
import chromadb
from unidecode import unidecode

from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
import re

# Static PDF file link
static_pdf_link = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Data-privacy-policy.pdf"

list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
            "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
            "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
            "google/flan-t5-xxl"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

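# NOTE: list_llm holds full Hub repo_ids, which is what HuggingFaceEndpoint
# needs; list_llm_simple keeps only the basename (e.g.
# "mistralai/Mistral-7B-Instruct-v0.2" -> "Mistral-7B-Instruct-v0.2") for the
# dropdown labels, so the two lists must stay index-aligned.
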
# Load PDF document and create doc splits
def load_doc(file_path, chunk_size, chunk_overlap):
    loader = PyPDFLoader(file_path)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

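# The splitter measures characters, not tokens: with the UI defaults
# (chunk_size=512, chunk_overlap=24), consecutive chunks share a 24-character
# window, so chunk 1 spans roughly [0, 512) and chunk 2 starts near
# character 488. RecursiveCharacterTextSplitter prefers paragraph and sentence
# boundaries, so actual spans can be shorter.
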
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

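# chromadb.EphemeralClient() keeps the whole index in memory: the collection is
# rebuilt on every "Process document" click and lost on restart. With no
# arguments, HuggingFaceEmbeddings falls back to its default
# sentence-transformers model, downloaded on first use.
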
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing HF tokenizer...")

    progress(0.5, desc="Initializing HF Hub...")
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            load_in_8bit=True,
        )
    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
    elif llm_model == "microsoft/phi-2":
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
            trust_remote_code=True,
            torch_dtype="auto",
        )
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=250,
            top_k=top_k,
        )
    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
    else:
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            temperature=temperature,
            max_new_tokens=max_tokens,
            top_k=top_k,
        )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain

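# ConversationalRetrievalChain wires three steps together: (1) condense the new
# question and chat history into a standalone question, (2) fetch the most
# similar chunks from the retriever, (3) "stuff" them into a single prompt for
# the LLM. With return_source_documents=True the retrieved chunks come back in
# response["source_documents"], which conversation() below displays as sources.
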
# Generate collection name for vector database
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', collection_name)
    return collection_name

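# Chroma collection names must be 3-63 characters and start/end with an
# alphanumeric character, which is what the sanitization above enforces.
# Hypothetical examples:
#   "Data-privacy-policy.pdf" -> "Data-privacy-policy"
#   "résumé 2024!.pdf"        -> "resume-2024Z"
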
# Initialize database
def initialize_database(chunk_size, chunk_overlap, progress=gr.Progress()):
    file_path = static_pdf_link
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(file_path)
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

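# llm_option arrives as an integer (the Dropdown below is declared with
# type="index") and is mapped back to the full repo_id via list_llm.
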
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

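# The chain's ConversationBufferMemory already tracks turns internally; this
# string-formatted copy ("User: ..." / "Assistant: ...") is what conversation()
# passes explicitly as chat_history. The `message` argument is unused here.
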
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1

    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            response_source1, response_source1_page,
            response_source2, response_source2_page,
            response_source3, response_source3_page)

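# The 9-tuple returned above must line up one-to-one with the outputs list of
# msg.submit() in demo(): updated chain state, cleared textbox, new chat
# history, then three (source text, page number) pairs.
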
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(
            """<center><h2>PDF-based chatbot</h2></center>
            <h3>Ask any questions about your PDF documents</h3>""")
        gr.Markdown(
            """<b>Note:</b> This AI assistant, using Langchain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
            The user interface explicitly shows multiple steps to help understand the RAG workflow.
            This chatbot takes past questions into account when generating answers (via conversational memory), and includes document references for clarity.<br>
            <br><b>Warning:</b> This Space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
            """)

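        # The gr.State objects above hold per-session Python objects (the vector
        # database handle and the QA chain) so they persist across event
        # callbacks without being rendered in the UI.
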
        with gr.Tab("Step 1 - Process document"):
            with gr.Row():
                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index",
                                  info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
                with gr.Row():
                    chunk_size = gr.Slider(64, 4096, value=512, step=32, label="Text chunk size",
                                           info="Text length of each document chunk being embedded into the vector database. Default is 512.")
                    chunk_overlap = gr.Slider(0, 1024, value=24, step=8, label="Text chunk overlap",
                                              info="Text overlap between each document chunk being embedded into the vector database. Default is 24.")

            initialize_db = gr.Button("Process document")

            with gr.Row():
                output_db = gr.Textbox(label="Database initialization steps", placeholder="", show_label=False)
            with gr.Accordion("Vector database collection details", open=False):
                collection = gr.Textbox(label="Collection name", placeholder="", show_label=False)

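        # Rough trade-off for the splitter sliders above: larger chunks give the
        # LLM more context per retrieved passage but dilute similarity search;
        # more overlap reduces the chance of splitting an answer across chunk
        # boundaries at the cost of extra embeddings.
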
        with gr.Tab("Step 2 - Initialize LLM"):
            with gr.Row():
                llm_options = gr.Dropdown(list_llm_simple, label="Choose open-source LLM",
                                          value="Mistral-7B-Instruct-v0.2", type="index",
                                          info="Choose among the proposed open-source LLMs")
            with gr.Accordion("Advanced LLM options", open=False):
                with gr.Row():
                    llm_temperature = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="LLM temperature",
                                                info="LLM sampling temperature, in [0.01,1.0] range. Default is 0.1.")
                    llm_max_tokens = gr.Slider(32, 1024, value=512, step=16, label="Max tokens",
                                               info="Maximum number of new tokens to be generated, in [32,1024] range. Default is 512.")
                    llm_top_k = gr.Slider(1, 40, value=20, step=1, label="Top K",
                                          info="The number of highest-probability vocabulary tokens to keep for top-k filtering. Default is 20.")

            initialize_llm = gr.Button("Initialize LLM")

            with gr.Row():
                output_llm = gr.Textbox(label="LLM initialization steps", placeholder="", show_label=False)

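        # Temperature near 0.01 makes answers close to deterministic, while
        # top_k caps how many candidate tokens are sampled at each step; both
        # are forwarded unchanged to HuggingFaceEndpoint in initialize_llmchain().
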
        with gr.Tab("Step 3 - Start chatting"):
            chatbot = gr.Chatbot(label="PDF chatbot", height=500)
            msg = gr.Textbox(label="Your question", placeholder="Type your question here...", show_label=False)
            clear = gr.Button("Clear chat")

            with gr.Accordion("Document sources (3)", open=False):
                gr.Markdown("Source 1")
                response_src1 = gr.Textbox(label="Source 1", placeholder="", show_label=False)
                response_src1_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)
                gr.Markdown("Source 2")
                response_src2 = gr.Textbox(label="Source 2", placeholder="", show_label=False)
                response_src2_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)
                gr.Markdown("Source 3")
                response_src3 = gr.Textbox(label="Source 3", placeholder="", show_label=False)
                response_src3_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)

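        # Only the first three retrieved chunks are displayed; the retriever may
        # return more (the LangChain default is k=4), and conversation() raises
        # an IndexError if fewer than three documents come back.
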
        initialize_db.click(initialize_database,
                            inputs=[chunk_size, chunk_overlap],
                            outputs=[vector_db, collection_name, output_db])
        initialize_llm.click(initialize_LLM,
                             inputs=[llm_options, llm_temperature, llm_max_tokens, llm_top_k, vector_db],
                             outputs=[qa_chain, output_llm])
        msg.submit(conversation,
                   inputs=[qa_chain, msg, chatbot],
                   outputs=[qa_chain, msg, chatbot, response_src1, response_src1_page, response_src2, response_src2_page,
                            response_src3, response_src3_page])
        clear.click(lambda: None, None, chatbot, queue=False)
        clear.click(lambda: None, None, msg, queue=False)

    return demo.queue().launch(debug=True)


# demo().launch(server_name="0.0.0.0")
if __name__ == "__main__":
    demo()
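
# A minimal local run, assuming the required packages are installed (gradio,
# langchain, langchain-community, chromadb, pypdf, unidecode, transformers) and
# HUGGINGFACEHUB_API_TOKEN is set for the free inference endpoints:
#   $ python app.py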