Update app.py
app.py CHANGED

@@ -3,8 +3,12 @@ import json
 import re
 import gradio as gr
 import pandas as pd
+import requests
+import random
+import urllib.parse
 from tempfile import NamedTemporaryFile
 from typing import List
+from bs4 import BeautifulSoup
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_community.vectorstores import FAISS
 from langchain_community.document_loaders import PyPDFLoader
@@ -119,7 +123,78 @@ def is_related_to_history(question, history, threshold=0.3):
     similarity = get_similarity(question, history_text)
     return similarity > threshold
 
-def ask_question(question, temperature, top_p, repetition_penalty):
+def extract_text_from_webpage(html):
+    soup = BeautifulSoup(html, 'html.parser')
+    for script in soup(["script", "style"]):
+        script.extract()  # Remove scripts and styles
+    text = soup.get_text()
+    lines = (line.strip() for line in text.splitlines())
+    chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
+    text = '\n'.join(chunk for chunk in chunks if chunk)
+    return text
+
+_useragent_list = [
+    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
+    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
+    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Edge/91.0.864.59 Safari/537.36",
+    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
+    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Safari/537.36",
+]
+
+def google_search(term, num_results=5, lang="en", timeout=5, safe="active", ssl_verify=None):
+    escaped_term = urllib.parse.quote_plus(term)
+    start = 0
+    all_results = []
+    max_chars_per_page = 8000  # Limit the number of characters from each webpage to stay under the token limit
+
+    with requests.Session() as session:
+        while start < num_results:
+            try:
+                user_agent = random.choice(_useragent_list)
+                headers = {
+                    'User-Agent': user_agent
+                }
+                resp = session.get(
+                    url="https://www.google.com/search",
+                    headers=headers,
+                    params={
+                        "q": term,
+                        "num": num_results - start,
+                        "hl": lang,
+                        "start": start,
+                        "safe": safe,
+                    },
+                    timeout=timeout,
+                    verify=ssl_verify,
+                )
+                resp.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                break
+
+            soup = BeautifulSoup(resp.text, "html.parser")
+            result_block = soup.find_all("div", attrs={"class": "g"})
+            if not result_block:
+                break
+            for result in result_block:
+                link = result.find("a", href=True)
+                if link:
+                    link = link["href"]
+                    try:
+                        webpage = session.get(link, headers=headers, timeout=timeout)
+                        webpage.raise_for_status()
+                        visible_text = extract_text_from_webpage(webpage.text)
+                        if len(visible_text) > max_chars_per_page:
+                            visible_text = visible_text[:max_chars_per_page] + "..."
+                        all_results.append({"link": link, "text": visible_text})
+                    except requests.exceptions.RequestException as e:
+                        all_results.append({"link": link, "text": None})
+                else:
+                    all_results.append({"link": None, "text": None})
+            start += len(result_block)
+    return all_results
+
+def ask_question(question, temperature, top_p, repetition_penalty, web_search):
     global conversation_history
 
     if not question:
@@ -129,17 +204,21 @@ def ask_question(question, temperature, top_p, repetition_penalty):
         answer = memory_database[question]
     else:
         embed = get_embeddings()
-        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
         model = get_model(temperature, top_p, repetition_penalty)
 
         history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
 
-        if is_related_to_history(question, conversation_history):
-            context_str = "No additional context needed. Please refer to the conversation history."
+        if web_search:
+            search_results = google_search(question)
+            context_str = "\n".join([result["text"] for result in search_results if result["text"]])
         else:
-            retriever = database.as_retriever()
-            relevant_docs = retriever.get_relevant_documents(question)
-            context_str = "\n".join([doc.page_content for doc in relevant_docs])
+            database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+            if is_related_to_history(question, conversation_history):
+                context_str = "No additional context needed. Please refer to the conversation history."
+            else:
+                retriever = database.as_retriever()
+                relevant_docs = retriever.get_relevant_documents(question)
+                context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
         prompt_val = ChatPromptTemplate.from_template(prompt)
         formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
@@ -220,9 +299,10 @@ with gr.Blocks() as demo:
     temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
     top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
     repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
+    web_search_checkbox = gr.Checkbox(label="Enable Web Search", value=False)
 
     def chat(question, history):
-        answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value)
+        answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value, web_search_checkbox.value)
         history.append((question, answer))
         return "", history
 
@@ -241,4 +321,4 @@ with gr.Blocks() as demo:
     clear_button.click(clear_cache, inputs=[], outputs=clear_output)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
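
A note on dependencies, hedged because the Space's requirements.txt is not shown in this diff: the new imports rely on the requests and beautifulsoup4 packages (bs4 is the import name for beautifulsoup4), so both would need to be available in the Space for app.py to start. A minimal requirements.txt addition could look like:

    beautifulsoup4
    requests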
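
As a quick local check of the new web-search path, a sketch along the following lines could exercise google_search (which calls extract_text_from_webpage internally) outside of Gradio. This is only an illustration: it assumes app.py is importable from the current directory, the query string is arbitrary, and Google may rate-limit or change its result markup, so an empty list is not necessarily a bug in the helper.

    # Hypothetical standalone check; not part of the commit above.
    from app import google_search

    results = google_search("what is FAISS", num_results=3)
    for result in results:
        snippet = (result["text"] or "")[:120].replace("\n", " ")
        print(result["link"], "->", snippet)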