Spaces:
Runtime error
Runtime error
development branch (#7)
Browse files* fix relative import
* add embeddings requirement
* update openai embeddings requirements...
* format responses appropriately
* add markdown response
* Fix newline formatting
* add threshold and top_k
* update response
* fix merge conflict
- buster/chatbot.py +41 -6
buster/chatbot.py
CHANGED
@@ -12,13 +12,16 @@ logging.basicConfig(level=logging.INFO)
|
|
12 |
|
13 |
|
14 |
# search through the reviews for a specific product
|
15 |
-
def rank_documents(df: pd.DataFrame, query: str, top_k: int =
|
16 |
product_embedding = get_embedding(
|
17 |
query,
|
18 |
engine=EMBEDDING_MODEL,
|
19 |
)
|
20 |
df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
|
21 |
|
|
|
|
|
|
|
22 |
if top_k == -1:
|
23 |
# return all results
|
24 |
n = len(df)
|
@@ -28,13 +31,43 @@ def rank_documents(df: pd.DataFrame, query: str, top_k: int = 3) -> pd.DataFrame
|
|
28 |
|
29 |
|
30 |
def engineer_prompt(question: str, documents: list[str]) -> str:
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
|
|
35 |
# rank the documents, get the highest scoring doc and generate the prompt
|
36 |
-
candidates = rank_documents(df, query=question, top_k=
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
documents = candidates.text.to_list()
|
|
|
38 |
prompt = engineer_prompt(question, documents)
|
39 |
|
40 |
logger.info(f"querying GPT...")
|
@@ -58,12 +91,14 @@ def answer_question(question: str, df) -> str:
|
|
58 |
GPT Response:\n{response_text}
|
59 |
"""
|
60 |
)
|
61 |
-
return response_text
|
|
|
62 |
except Exception as e:
|
63 |
import traceback
|
64 |
|
65 |
logging.error(traceback.format_exc())
|
66 |
-
|
|
|
67 |
|
68 |
|
69 |
def load_embeddings(path: str) -> pd.DataFrame:
|
|
|
12 |
|
13 |
|
14 |
# search through the reviews for a specific product
|
15 |
+
def rank_documents(df: pd.DataFrame, query: str, top_k: int = 1, thresh: float = None) -> pd.DataFrame:
|
16 |
product_embedding = get_embedding(
|
17 |
query,
|
18 |
engine=EMBEDDING_MODEL,
|
19 |
)
|
20 |
df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
|
21 |
|
22 |
+
if thresh:
|
23 |
+
df = df[df.similarity > thresh]
|
24 |
+
|
25 |
if top_k == -1:
|
26 |
# return all results
|
27 |
n = len(df)
|
|
|
31 |
|
32 |
|
33 |
def engineer_prompt(question: str, documents: list[str]) -> str:
|
34 |
+
documents_str = " ".join(documents)
|
35 |
+
if len(documents_str) > 3000:
|
36 |
+
logger.info("truncating documents to fit...")
|
37 |
+
documents_str = documents_str[0:3000]
|
38 |
+
return documents_str + "\nNow answer the following question:\n" + question
|
39 |
+
|
40 |
+
|
41 |
+
def format_response(response_text, sources_url=None):
|
42 |
+
|
43 |
+
response = f"{response_text}\n"
|
44 |
+
|
45 |
+
if sources_url:
|
46 |
+
response += f"<br><br>Here are the sources I used to answer your question:\n"
|
47 |
+
for url in sources_url:
|
48 |
+
response += f"<br>[{url}]({url})\n"
|
49 |
|
50 |
+
response += "<br><br>"
|
51 |
+
response += """
|
52 |
+
```
|
53 |
+
I'm a bot 🤖 and not always perfect.
|
54 |
+
For more info, view the full documentation here (https://docs.mila.quebec/) or contact [email protected]
|
55 |
+
```
|
56 |
+
"""
|
57 |
+
return response
|
58 |
|
59 |
+
|
60 |
+
def answer_question(question: str, df, top_k: int = 1, thresh: float = None) -> str:
|
61 |
# rank the documents, get the highest scoring doc and generate the prompt
|
62 |
+
candidates = rank_documents(df, query=question, top_k=top_k, thresh=thresh)
|
63 |
+
|
64 |
+
logger.info(f"candidate responses: {candidates}")
|
65 |
+
|
66 |
+
if len(candidates) == 0:
|
67 |
+
return format_response("I did not find any relevant documentation related to your question.")
|
68 |
+
|
69 |
documents = candidates.text.to_list()
|
70 |
+
sources_url = candidates.url.to_list()
|
71 |
prompt = engineer_prompt(question, documents)
|
72 |
|
73 |
logger.info(f"querying GPT...")
|
|
|
91 |
GPT Response:\n{response_text}
|
92 |
"""
|
93 |
)
|
94 |
+
return format_response(response_text, sources_url)
|
95 |
+
|
96 |
except Exception as e:
|
97 |
import traceback
|
98 |
|
99 |
logging.error(traceback.format_exc())
|
100 |
+
response = "Oops, something went wrong. Try again later!"
|
101 |
+
return format_response(response)
|
102 |
|
103 |
|
104 |
def load_embeddings(path: str) -> pd.DataFrame:
|