Shreyas094 committed on
Commit 487fdcd · verified · 1 Parent(s): 5999644

Update app.py

Files changed (1)
  1. app.py +98 -111
app.py CHANGED
@@ -20,8 +20,9 @@ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")

 MODELS = [
-    "mistralai/Mistral-7B-Instruct-v0.3",
+    "google/gemma-2-9b",
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "mistralai/Mistral-7B-Instruct-v0.3",
     "microsoft/Phi-3-mini-4k-instruct"
 ]

@@ -77,53 +78,76 @@ def update_vectors(files, parser):

     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}."

-def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=5, temperature=0.2, should_stop=False):
+def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=3, temperature=0.2, stop_clicked=None):
     print(f"Starting generate_chunked_response with {num_calls} calls")
     client = InferenceClient(model, token=huggingface_token)
-    full_response = ""
+    full_responses = []
     messages = [{"role": "user", "content": prompt}]

     for i in range(num_calls):
         print(f"Starting API call {i+1}")
-        if should_stop:
+        if (isinstance(stop_clicked, gr.State) and stop_clicked.value) or stop_clicked:
             print("Stop clicked, breaking loop")
             break
         try:
+            response = ""
             for message in client.chat_completion(
                 messages=messages,
                 max_tokens=max_tokens,
                 temperature=temperature,
                 stream=True,
             ):
-                if should_stop:
+                if (isinstance(stop_clicked, gr.State) and stop_clicked.value) or stop_clicked:
                     print("Stop clicked during streaming, breaking")
                     break
                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
                     chunk = message.choices[0].delta.content
-                    full_response += chunk
-            print(f"API call {i+1} completed")
+                    response += chunk
+            print(f"API call {i+1} response: {response[:100]}...")
+            full_responses.append(response)
         except Exception as e:
             print(f"Error in generating response: {str(e)}")

-    # Clean up the response
-    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
+    # Combine responses and clean up
+    combined_response = " ".join(full_responses)
+    clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', combined_response, flags=re.DOTALL)
     clean_response = clean_response.replace("Using the following context:", "").strip()
     clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()

-    # Remove duplicate paragraphs and sentences
-    paragraphs = clean_response.split('\n\n')
+    # Split the response into main content and sources
+    parts = re.split(r'\n\s*Sources:\s*\n', clean_response, flags=re.IGNORECASE, maxsplit=1)
+    main_content = parts[0].strip()
+    sources = parts[1].strip() if len(parts) > 1 else ""
+
+    # Process main content
+    paragraphs = main_content.split('\n\n')
     unique_paragraphs = []
     for paragraph in paragraphs:
         if paragraph not in unique_paragraphs:
-            sentences = paragraph.split('. ')
             unique_sentences = []
+            sentences = paragraph.split('. ')
             for sentence in sentences:
                 if sentence not in unique_sentences:
                     unique_sentences.append(sentence)
             unique_paragraphs.append('. '.join(unique_sentences))

-    final_response = '\n\n'.join(unique_paragraphs)
+    final_content = '\n\n'.join(unique_paragraphs)

+    # Process sources
+    if sources:
+        source_lines = sources.split('\n')
+        unique_sources = []
+        for line in source_lines:
+            if line.strip() and line not in unique_sources:
+                unique_sources.append(line)
+        final_sources = '\n'.join(unique_sources)
+        final_response = f"{final_content}\n\nSources:\n{final_sources}"
+    else:
+        final_response = final_content
+
+    # Remove any content after the sources
+    final_response = re.sub(r'(Sources:.*?)(?:\n\n|\Z).*', r'\1', final_response, flags=re.DOTALL)
+
     print(f"Final clean response: {final_response[:100]}...")
     return final_response

@@ -137,104 +161,82 @@ class CitingSources(BaseModel):
         ...,
         description="List of sources to cite. Should be an URL of the source."
     )
-def chatbot_interface(message, history, use_web_search, model, temperature, num_calls):
-    if not message.strip():
-        return "", history
-
-    history = history + [(message, "")]
-
-    try:
-        if use_web_search:
-            for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
-                history[-1] = (message, f"{main_content}\n\n{sources}")
-                yield history
-        else:
-            for partial_response in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
-                history[-1] = (message, partial_response)
-                yield history
-    except gr.CancelledError:
-        yield history

-def retry_last_response(history, use_web_search, model, temperature, num_calls):
-    if not history:
-        return history
-
-    last_user_msg = history[-1][0]
-    history = history[:-1]  # Remove the last response
-
-    return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
-
-def respond(message, history, model, temperature, num_calls, use_web_search):
-    if use_web_search:
-        for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
-            yield f"{main_content}\n\n{sources}"
-    else:
-        for partial_response, _ in get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature):
-            yield partial_response
-
-def get_response_with_search(query, model, num_calls=5, temperature=0.2):
+def get_response_with_search(query, model, num_calls=3, temperature=0.2, stop_clicked=None):
     search_results = duckduckgo_search(query)
     context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
                         for result in search_results if 'body' in result)

-    prompt = f"""Using the following context:
+    prompt = f"""<s>[INST] Using the following context:
 {context}
 Write a detailed and complete research document that fulfills the following user request: '{query}'
-After writing the document, please provide a list of sources used in your response."""
+After writing the document, please provide a list of sources used in your response. [/INST]"""

-    client = InferenceClient(model, token=huggingface_token)
+    generated_text = generate_chunked_response(prompt, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)
+
+    # Clean the response
+    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
+    clean_text = clean_text.replace("Using the following context:", "").strip()

-    main_content = ""
-    for i in range(num_calls):
-        for message in client.chat_completion(
-            messages=[{"role": "user", "content": prompt}],
-            max_tokens=1000,
-            temperature=temperature,
-            stream=True,
-        ):
-            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                chunk = message.choices[0].delta.content
-                main_content += chunk
-                yield main_content, ""  # Yield partial main content without sources
+    # Split the content and sources
+    parts = clean_text.split("Sources:", 1)
+    main_content = parts[0].strip()
+    sources = parts[1].strip() if len(parts) > 1 else ""
+
+    return main_content, sources

-def get_response_from_pdf(query, model, num_calls=5, temperature=0.2):
+def get_response_from_pdf(query, model, num_calls=3, temperature=0.2, stop_clicked=None):
     embed = get_embeddings()
     if os.path.exists("faiss_database"):
         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
     else:
-        yield "No documents available. Please upload PDF documents to answer questions."
-        return
+        return "No documents available. Please upload PDF documents to answer questions."

     retriever = database.as_retriever()
     relevant_docs = retriever.get_relevant_documents(query)
     context_str = "\n".join([doc.page_content for doc in relevant_docs])

-    prompt = f"""Using the following context from the PDF documents:
+    prompt = f"""<s>[INST] Using the following context from the PDF documents:
 {context_str}
-Write a detailed and complete response that fully answers the following user question.
-Ensure your response covers all relevant information and is not cut off: '{query}'
-If the response is long, please continue until you have provided a comprehensive answer."""
+Write a detailed and complete response that answers the following user question: '{query}'
+Do not include a list of sources in your response. [/INST]"""

-    client = InferenceClient(model, token=huggingface_token)
-
-    response = ""
-    for i in range(num_calls):
-        for message in client.chat_completion(
-            messages=[{"role": "user", "content": prompt}],
-            max_tokens=2000,
-            temperature=temperature,
-            stream=True,
-        ):
-            if message.choices and message.choices[0].delta and message.choices[0].delta.content:
-                chunk = message.choices[0].delta.content
-                response += chunk
-                yield response, ""  # Yield accumulated response with an empty string for consistency
+    generated_text = generate_chunked_response(prompt, model, num_calls=num_calls, temperature=temperature, stop_clicked=stop_clicked)

-def vote(data: gr.LikeData):
-    if data.liked:
-        print(f"You upvoted this response: {data.value}")
+    # Clean the response
+    clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
+    clean_text = clean_text.replace("Using the following context from the PDF documents:", "").strip()
+
+    return clean_text
+
+def chatbot_interface(message, history, use_web_search, model, temperature):
+    if not message.strip():  # Check if the message is empty or just whitespace
+        return history
+
+    if use_web_search:
+        main_content, sources = get_response_with_search(message, model, temperature)
+        formatted_response = f"{main_content}\n\nSources:\n{sources}"
     else:
-        print(f"You downvoted this response: {data.value}")
+        response = get_response_from_pdf(message, model, temperature)
+        formatted_response = response
+
+    # Check if the last message in history is the same as the current message
+    if history and history[-1][0] == message:
+        # Replace the last response instead of adding a new one
+        history[-1] = (message, formatted_response)
+    else:
+        # Add the new message-response pair
+        history.append((message, formatted_response))
+
+    return history
+
+
+def respond(message, history, model, temperature, num_calls, use_web_search):
+    if use_web_search:
+        main_content, sources = get_response_with_search(message, model, num_calls=num_calls, temperature=temperature)
+        return f"{main_content}\n\nSources:\n{sources}"
+    else:
+        return get_response_from_pdf(message, model, num_calls=num_calls, temperature=temperature)

 css = """
 /* Add your custom CSS here */
@@ -245,34 +247,18 @@ demo = gr.ChatInterface(
     additional_inputs=[
         gr.Dropdown(choices=MODELS, label="Select Model", value=MODELS[1]),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.2, step=0.1, label="Temperature"),
-        gr.Slider(minimum=1, maximum=5, value=1, step=1, label="Number of API Calls"),
+        gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Number of API Calls"),
         gr.Checkbox(label="Use Web Search", value=False)
     ],
     title="AI-powered Web Search and PDF Chat Assistant",
     description="Chat with your PDFs or use web search to answer questions.",
-    theme=gr.themes.Soft(
-        primary_hue="orange",
-        secondary_hue="amber",
-        neutral_hue="gray",
-        font=[gr.themes.GoogleFont("Exo"), "ui-sans-serif", "system-ui", "sans-serif"]
-    ).set(
-        body_background_fill_dark="#0c0505",
-        block_background_fill_dark="#0c0505",
-        block_border_width="1px",
-        block_title_background_fill_dark="#1b0f0f",
-        input_background_fill_dark="#140b0b",
-        button_secondary_background_fill_dark="#140b0b",
-        border_color_accent_dark="#1b0f0f",
-        border_color_primary_dark="#1b0f0f",
-        background_fill_secondary_dark="#0c0505",
-        color_accent_soft_dark="transparent",
-        code_background_fill_dark="#140b0b"
-    ),
+    theme=gr.themes.Soft(),
     css=css,
     examples=[
-        ["Tell me about the contents of the uploaded PDFs."],
-        ["What are the main topics discussed in the documents?"],
-        ["Can you summarize the key points from the PDFs?"]
+        ["What are the latest developments in AI?"],
+        ["Tell me about recent updates on GitHub"],
+        ["What are the best hotels in Galapagos, Ecuador?"],
+        ["Summarize recent advancements in Python programming"],
    ],
     cache_examples=False,
     analytics_enabled=False,
@@ -281,6 +267,7 @@ demo = gr.ChatInterface(
 # Add file upload functionality
 with demo:
     gr.Markdown("## Upload PDF Documents")
+
     with gr.Row():
         file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
         parser_dropdown = gr.Dropdown(choices=["pypdf", "llamaparse"], label="Select PDF Parser", value="llamaparse")
@@ -302,4 +289,4 @@ with demo:
 )

 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(share=True)
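
For orientation, a minimal sketch of calling the reworked generate_chunked_response(); the prompt text and the stop values below are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch, assuming it runs inside app.py where MODELS and
# generate_chunked_response are already defined. stop_clicked is the new parameter:
# a falsy value lets all num_calls API calls run; a truthy value makes the function
# break out of the loop (the same check is also applied while streaming).
prompt = "<s>[INST] Summarize the key points of the uploaded documents. [/INST]"

text = generate_chunked_response(prompt, MODELS[1], num_calls=3, stop_clicked=False)    # never interrupted
print(text[:200])

aborted = generate_chunked_response(prompt, MODELS[1], num_calls=3, stop_clicked=True)  # stops before the first call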