CCCDev committed (verified)
Commit 7d8dbca
Parent(s): 8ddfd2d

Update app.py

Files changed (1):
  1. app.py +60 -121
app.py CHANGED
@@ -22,25 +22,16 @@ import tqdm
 import accelerate
 import re
 
-# Static PDF file link
-static_pdf_link = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Data-privacy-policy.pdf"
-
-list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1",
-            "mistralai/Mistral-7B-Instruct-v0.1", "google/gemma-7b-it", "google/gemma-2b-it",
-            "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
-            "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2",
-            "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
-            "google/flan-t5-xxl"]
+list_llm = ["mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+pdf_url = "path/to/your/static.pdf"  # Replace with your static PDF URL or path
 
 
 # Load PDF document and create doc splits
-def load_doc(file_path, chunk_size, chunk_overlap):
-    loader = PyPDFLoader(file_path)
+def load_doc(pdf_url, chunk_size, chunk_overlap):
+    loader = PyPDFLoader(pdf_url)
     pages = loader.load()
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
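A minimal sketch (not part of the commit) of how the new load_doc signature can be exercised once the pdf_url placeholder points at a real PDF. The import paths assume the langchain-community package layout and may differ from the versions this Space pins; "sample.pdf" and the 512/24 chunk settings (the defaults auto_initialize uses later in this diff) are illustrative.

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_doc(pdf_url, chunk_size, chunk_overlap):
    # Same logic as the committed function: load pages, then split into overlapping chunks
    loader = PyPDFLoader(pdf_url)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(pages)

if __name__ == "__main__":
    splits = load_doc("sample.pdf", chunk_size=512, chunk_overlap=24)  # illustrative path and defaults
    print(len(splits), "chunks")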
@@ -60,44 +51,13 @@ def create_db(splits, collection_name):
 
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Initializing HF tokenizer...")
-
     progress(0.5, desc="Initializing HF Hub...")
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            load_in_8bit=True,
-        )
-    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
-        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
-    elif llm_model == "microsoft/phi-2":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            trust_remote_code=True,
-            torch_dtype="auto",
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=250,
-            top_k=top_k,
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-    else:
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
+    )
 
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
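The simplified initializer above passes only sampling parameters to HuggingFaceEndpoint; the endpoint also needs a Hugging Face token, normally picked up from the Space's HUGGINGFACEHUB_API_TOKEN environment variable. A standalone sketch of the equivalent call, not part of the commit (the import path and the explicit token argument are assumptions about the pinned langchain version), using the values auto_initialize supplies later in this diff:

import os
from langchain_community.llms import HuggingFaceEndpoint

# Sketch only: same call as the committed code, with the token made explicit
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # the only entry left in list_llm
    temperature=0.1,          # defaults used by auto_initialize below
    max_new_tokens=512,
    top_k=20,
    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),  # otherwise read from the environment
)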
@@ -132,18 +92,14 @@ def create_collection_name(filepath):
     collection_name = 'A' + collection_name[1:]
     if not collection_name[-1].isalnum():
         collection_name = collection_name[:-1] + 'Z'
-    print('Filepath: ', filepath)
-    print('Collection name: ', collection_name)
     return collection_name
 
 
 # Initialize database
-def initialize_database(chunk_size, chunk_overlap, progress=gr.Progress()):
-    file_path = static_pdf_link
-    progress(0.1, desc="Creating collection name...")
-    collection_name = create_collection_name(file_path)
+def initialize_database(pdf_url, chunk_size, chunk_overlap, progress=gr.Progress()):
+    collection_name = create_collection_name(pdf_url)
     progress(0.25, desc="Loading document...")
-    doc_splits = load_doc(file_path, chunk_size, chunk_overlap)
+    doc_splits = load_doc(pdf_url, chunk_size, chunk_overlap)
     progress(0.5, desc="Generating vector database...")
     vector_db = create_db(doc_splits, collection_name)
     progress(0.9, desc="Done!")
@@ -152,7 +108,6 @@ def initialize_database(chunk_size, chunk_overlap, progress=gr.Progress()):
 
 def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     llm_name = list_llm[llm_option]
-    print("llm_name: ", llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
@@ -178,7 +133,6 @@ def conversation(qa_chain, message, history):
     response_source1_page = response_sources[0].metadata["page"] + 1
     response_source2_page = response_sources[1].metadata["page"] + 1
     response_source3_page = response_sources[2].metadata["page"] + 1
-
     new_history = history + [(message, response_answer)]
     return qa_chain, gr.update(
         value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
@@ -200,72 +154,57 @@ def demo():
         <br><b>Warning:</b> This space uses the free CPU Basic hardware from Hugging Face. Some steps and LLM models used below (free inference endpoints) can take some time to generate a reply.
         """)
 
-        with gr.Tab("Step 2 - Process document"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index",
-                                  info="Choose your vector database")
-            with gr.Accordion("Advanced options - Document text splitter", open=False):
-                with gr.Row():
-                    chunk_size = gr.Slider(64, 4096, value=512, step=32, label="Text chunk size",
-                                           info="Text length of each document chunk being embedded into the vector database. Default is 512.")
-                    chunk_overlap = gr.Slider(0, 1024, value=24, step=8, label="Text chunk overlap",
-                                              info="Text overlap between each document chunk being embedded into the vector database. Default is 24.")
+        with gr.Tab("Step 1 - Upload PDF"):
+            gr.Markdown("Using static PDF link: path/to/your/static.pdf")
 
-            initialize_db = gr.Button("Process document")
+        with gr.Tab("Step 2 - Process document"):
+            gr.Markdown("Processing document automatically...")
 
-            with gr.Row():
-                output_db = gr.Textbox(label="Database initialization steps", placeholder="", show_label=False)
-            with gr.Accordion("Vector database collection details", open=False):
-                collection = gr.Textbox(label="Collection name", placeholder="", show_label=False)
+        with gr.Tab("Step 3 - Initialize QA chain"):
+            gr.Markdown("Initializing QA chain automatically...")
 
-        with gr.Tab("Step 3 - Initialize LLM"):
-            with gr.Row():
-                llm_options = gr.Dropdown(list_llm_simple, label="Choose open-source LLM",
-                                          value="Mistral-7B-Instruct-v0.2",
-                                          info="Choose among the proposed open-source LLMs")
-            with gr.Accordion("Advanced LLM options", open=False):
+        with gr.Tab("Step 4 - Chatbot"):
+            chatbot = gr.Chatbot(height=300)
+            with gr.Accordion("Advanced - Document references", open=False):
                 with gr.Row():
-                    llm_temperature = gr.Slider(0.01, 1.0, value=0.1, step=0.01, label="LLM temperature",
-                                                info="LLM sampling temperature, in [0.01,1.0] range. Default is 0.1")
-                    llm_max_tokens = gr.Slider(32, 1024, value=512, step=16, label="Max tokens",
-                                               info="Maximum number of new tokens to be generated, in [32,1024] range. Default is 512")
-                    llm_top_k = gr.Slider(1, 40, value=20, step=1, label="Top K",
-                                          info="The number of highest probability vocabulary tokens to keep for top-k-filtering. Default is 20.")
-
-            initialize_llm = gr.Button("Initialize LLM")
-
+                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                    source1_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                    source2_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                    source3_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-                output_llm = gr.Textbox(label="LLM initialization steps", placeholder="", show_label=False)
-
-        with gr.Tab("Step 4 - Start chatting"):
-            chatbot = gr.Chatbot(label="PDF chatbot", height=500)
-            msg = gr.Textbox(label="Your question", placeholder="Type your question here...", show_label=False)
-            clear = gr.Button("Clear chat")
-
-            with gr.Accordion("Document sources (3)", open=False):
-                gr.Markdown("Source 1")
-                response_src1 = gr.Textbox(label="Source 1", placeholder="", show_label=False)
-                response_src1_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)
-                gr.Markdown("Source 2")
-                response_src2 = gr.Textbox(label="Source 2", placeholder="", show_label=False)
-                response_src2_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)
-                gr.Markdown("Source 3")
-                response_src3 = gr.Textbox(label="Source 3", placeholder="", show_label=False)
-                response_src3_page = gr.Number(label="Page number", value=0, precision=0, interactive=False)
-
-        initialize_db.click(initialize_database,
-                            inputs=[chunk_size, chunk_overlap],
-                            outputs=[vector_db, collection_name, output_db])
-        initialize_llm.click(initialize_LLM,
-                             inputs=[llm_options, llm_temperature, llm_max_tokens, llm_top_k, vector_db],
-                             outputs=[qa_chain, output_llm])
-        msg.submit(conversation,
-                   inputs=[qa_chain, msg, chatbot],
-                   outputs=[chatbot, msg, chatbot, response_src1, response_src1_page, response_src2, response_src2_page,
-                            response_src3, response_src3_page])
-        clear.click(lambda: None, None, chatbot, queue=False)
-        clear.click(lambda: None, None, msg, queue=False)
-
+                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
+            with gr.Row():
+                submit_btn = gr.Button("Submit message")
+                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
+
+        # Automatic preprocessing
+        db_progress = gr.Textbox(label="Vector database initialization", value="Initializing...")
+        db_btn = gr.Button("Generate vector database", visible=False)
+        qachain_btn = gr.Button("Initialize Question Answering chain", visible=False)
+        llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+
+        def auto_initialize():
+            vector_db, collection_name, db_status = initialize_database(pdf_url, 512, 24)
+            qa_chain, llm_status = initialize_LLM(0, 0.1, 512, 20, vector_db)
+            return vector_db, collection_name, db_status, qa_chain, llm_status, "Initialization complete."
+
+        demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])
+
+        # Chatbot events
+        msg.submit(conversation, \
+                   inputs=[qa_chain, msg, chatbot], \
+                   outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3,
+                            source3_page], \
+                   queue=False)
+        submit_btn.click(conversation, \
+                         inputs=[qa_chain, msg, chatbot], \
+                         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page,
+                                  doc_source3, source3_page], \
+                         queue=False)
 
     return demo.queue().launch(debug=True)
 
 
 
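In the new demo(), auto_initialize feeds the vector database, collection name, and QA chain into components named vector_db, collection_name, and qa_chain that are not created in the hunks shown here; presumably they are gr.State holders defined elsewhere in app.py. A minimal runnable sketch of that wiring under this assumption (stand-in functions replace the real ones from the diff, and the component names simply follow the diff; this is not the committed code):

import gradio as gr

# Stand-ins so the sketch runs on its own; app.py uses the real
# initialize_database / initialize_LLM defined in the diff above.
def initialize_database(pdf_url, chunk_size, chunk_overlap):
    return object(), "collection", "Database ready"

def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    return object(), "QA chain ready"

pdf_url = "path/to/your/static.pdf"  # placeholder, as in the commit

with gr.Blocks() as demo:
    vector_db = gr.State()        # assumed State holders for the non-UI objects
    collection_name = gr.State()
    qa_chain = gr.State()
    db_progress = gr.Textbox(label="Vector database initialization")
    llm_progress = gr.Textbox(label="QA chain initialization")

    def auto_initialize():
        vdb, coll, db_status = initialize_database(pdf_url, 512, 24)
        chain, llm_status = initialize_LLM(0, 0.1, 512, 20, vdb)
        return vdb, coll, db_status, chain, llm_status  # one value per output component below

    demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])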
210