Shreyas094 commited on
Commit
ea42356
·
verified ·
1 Parent(s): c7c58a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -122,7 +122,7 @@ class CitingSources(BaseModel):
122
  description="List of sources to cite. Should be an URL of the source."
123
  )
124
 
125
- def get_response_from_pdf(query, temperature=0.7, repetition_penalty=1.1):
126
  embed = get_embeddings()
127
  if os.path.exists("faiss_database"):
128
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -138,7 +138,7 @@ def get_response_from_pdf(query, temperature=0.7, repetition_penalty=1.1):
138
  Write a detailed and complete response that answers the following user question: '{query}'
139
  Do not include a list of sources in your response. [/INST]"""
140
 
141
- generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)
142
 
143
  # Clean the response
144
  clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
@@ -146,7 +146,7 @@ Do not include a list of sources in your response. [/INST]"""
146
 
147
  return clean_text
148
 
149
- def get_response_with_search(query, temperature=0.7, repetition_penalty=1.1):
150
  search_results = duckduckgo_search(query)
151
  context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
152
  for result in search_results if 'body' in result)
@@ -156,7 +156,7 @@ def get_response_with_search(query, temperature=0.7, repetition_penalty=1.1):
156
  Write a detailed and complete research document that fulfills the following user request: '{query}'
157
  After writing the document, please provide a list of sources used in your response. [/INST]"""
158
 
159
- generated_text = generate_chunked_response(prompt, temperature=temperature, repetition_penalty=repetition_penalty)
160
 
161
  # Clean the response
162
  clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
@@ -169,12 +169,12 @@ After writing the document, please provide a list of sources used in your respon
169
 
170
  return main_content, sources
171
 
172
- def chatbot_interface(message, history, use_web_search, temperature, repetition_penalty):
173
  if use_web_search:
174
- main_content, sources = get_response_with_search(message, temperature, repetition_penalty)
175
  formatted_response = f"{main_content}\n\nSources:\n{sources}"
176
  else:
177
- response = get_response_from_pdf(message, temperature, repetition_penalty)
178
  formatted_response = response
179
 
180
  history.append((message, formatted_response))
@@ -197,8 +197,7 @@ with gr.Blocks() as demo:
197
  use_web_search = gr.Checkbox(label="Use Web Search", value=False)
198
 
199
  with gr.Row():
200
- temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
201
- repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition Penalty")
202
 
203
  submit = gr.Button("Submit")
204
 
@@ -213,10 +212,10 @@ with gr.Blocks() as demo:
213
  )
214
 
215
  submit.click(chatbot_interface,
216
- inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
217
  outputs=[chatbot])
218
  msg.submit(chatbot_interface,
219
- inputs=[msg, chatbot, use_web_search, temperature_slider, repetition_penalty_slider],
220
  outputs=[chatbot])
221
 
222
  gr.Markdown(
 
122
  description="List of sources to cite. Should be an URL of the source."
123
  )
124
 
125
+ def get_response_from_pdf(query, temperature=0.7):
126
  embed = get_embeddings()
127
  if os.path.exists("faiss_database"):
128
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
138
  Write a detailed and complete response that answers the following user question: '{query}'
139
  Do not include a list of sources in your response. [/INST]"""
140
 
141
+ generated_text = generate_chunked_response(prompt, temperature=temperature)
142
 
143
  # Clean the response
144
  clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
 
146
 
147
  return clean_text
148
 
149
+ def get_response_with_search(query, temperature=0.7):
150
  search_results = duckduckgo_search(query)
151
  context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
152
  for result in search_results if 'body' in result)
 
156
  Write a detailed and complete research document that fulfills the following user request: '{query}'
157
  After writing the document, please provide a list of sources used in your response. [/INST]"""
158
 
159
+ generated_text = generate_chunked_response(prompt, temperature=temperature)
160
 
161
  # Clean the response
162
  clean_text = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', generated_text, flags=re.DOTALL)
 
169
 
170
  return main_content, sources
171
 
172
+ def chatbot_interface(message, history, use_web_search, temperature):
173
  if use_web_search:
174
+ main_content, sources = get_response_with_search(message, temperature)
175
  formatted_response = f"{main_content}\n\nSources:\n{sources}"
176
  else:
177
+ response = get_response_from_pdf(message, temperature)
178
  formatted_response = response
179
 
180
  history.append((message, formatted_response))
 
197
  use_web_search = gr.Checkbox(label="Use Web Search", value=False)
198
 
199
  with gr.Row():
200
+ temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Temperature")
 
201
 
202
  submit = gr.Button("Submit")
203
 
 
212
  )
213
 
214
  submit.click(chatbot_interface,
215
+ inputs=[msg, chatbot, use_web_search, temperature_slider],
216
  outputs=[chatbot])
217
  msg.submit(chatbot_interface,
218
+ inputs=[msg, chatbot, use_web_search, temperature_slider],
219
  outputs=[chatbot])
220
 
221
  gr.Markdown(