jerpint committed
Commit fa9ac7e · unverified · 1 Parent(s): 8756061

Add slackbot support (#12)

* fix relative import

* add embeddings requirement

* update openai embeddings requirements...

* format responses appropriately

* add markdown response

* Fix newline formatting

* add threshold and top_k

* update response

* fix merge conflict

* Add slackbot

* refactor with a nice config interface

* add TODO

* isort

* add dataclass for chatbot config

* black

* Add support for orion bot

* format text

* Update docs

* use default_factory for dataclass

* Update app home tab

* update unk tokens

* move init to function

* Add logging

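For reference, the new config-driven interface introduced here is used roughly as follows. This is a minimal sketch based on app.py and the ChatbotConfig defaults further down; it assumes OPENAI_API_KEY is set and that the embeddings CSV exists, and the question string is only a placeholder.

    from buster.chatbot import Chatbot, ChatbotConfig

    # Any field left out falls back to the ChatbotConfig defaults.
    cfg = ChatbotConfig(
        documents_csv="buster/data/document_embeddings.csv",
        top_k=3,
        thresh=0.7,
        link_format="slack",
    )

    # Loads the document embeddings and the "unknown answer" embedding up front.
    chatbot = Chatbot(cfg)

    # Ranks documents, builds the prompt, calls the completion API, and formats the reply.
    answer = chatbot.process_input("Where should I put my datasets when I am running a job?")
    print(answer)
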
Files changed (2)
  1. app.py +144 -0
  2. buster/chatbot.py +189 -85
app.py ADDED
@@ -0,0 +1,144 @@
+import os
+
+from slack_bolt import App
+
+from buster.chatbot import Chatbot, ChatbotConfig
+
+MILA_CLUSTER_CHANNEL = "C04LR4H9KQA"
+ORION_CHANNEL = "C04LYHGUYB0"
+
+buster_cfg = ChatbotConfig(
+    documents_csv="buster/data/document_embeddings.csv",
+    unknown_prompt="This doesn't seem to be related to cluster usage. I am not sure how to answer.",
+    embedding_model="text-embedding-ada-002",
+    top_k=3,
+    thresh=0.7,
+    max_chars=3000,
+    completion_kwargs={
+        "engine": "text-davinci-003",
+        "max_tokens": 200,
+    },
+    separator="\n",
+    link_format="slack",
+    text_after_response="""I'm a bot 🤖 and not always perfect.
+    For more info, view the full documentation here (https://docs.mila.quebec/) or contact [email protected]
+    """,
+    text_before_prompt="""
+    You are a slack chatbot assistant answering technical questions about a cluster.
+    Make sure to format your answers in Markdown format, including code blocks and snippets.
+    Do not include any links to urls or hyperlinks in your answers.
+
+    If you do not know the answer to a question, or if it is completely irrelevant to cluster usage, simply reply with:
+
+    'This doesn't seem to be related to cluster usage.'
+
+    For example:
+
+    What is the meaning of life on the cluster?
+
+    This doesn't seem to be related to cluster usage.
+
+    Now answer the following question:
+    """,
+)
+buster_chatbot = Chatbot(buster_cfg)
+
+orion_cfg = ChatbotConfig(
+    documents_csv="buster/data/document_embeddings_orion.csv",
+    unknown_prompt="This doesn't seem to be related to the orion library. I am not sure how to answer.",
+    embedding_model="text-embedding-ada-002",
+    top_k=3,
+    thresh=0.7,
+    max_chars=3000,
+    completion_kwargs={
+        "engine": "text-davinci-003",
+        "max_tokens": 200,
+    },
+    separator="\n",
+    link_format="slack",
+    text_after_response="I'm a bot 🤖 and not always perfect.",
+    text_before_prompt="""You are a slack chatbot assistant answering technical questions about orion, a hyperparameter optimization library written in python.
+    Make sure to format your answers in Markdown format, including code blocks and snippets.
+    Do not include any links to urls or hyperlinks in your answers.
+
+    If you do not know the answer to a question, or if it is completely irrelevant to the library usage, simply reply with:
+
+    'This doesn't seem to be related to the orion library.'
+
+    For example:
+
+    What is the meaning of life for orion?
+
+    This doesn't seem to be related to the orion library.
+
+    Now answer the following question:
+    """,
+)
+orion_chatbot = Chatbot(orion_cfg)
+
+app = App(token=os.environ.get("SLACK_BOT_TOKEN"), signing_secret=os.environ.get("SLACK_SIGNING_SECRET"))
+
+
+@app.event("app_mention")
+def respond_to_question(event, say):
+    print(event)
+
+    # user's text
+    text = event["text"]
+    channel = event["channel"]
+
+    if channel == MILA_CLUSTER_CHANNEL:
+        print("*******using BUSTER********")
+        answer = buster_chatbot.process_input(text)
+    elif channel == ORION_CHANNEL:
+        print("*******using ORION********")
+        answer = orion_chatbot.process_input(text)
+
+    # responds to the message in the thread
+    thread_ts = event["event_ts"]
+
+    say(text=answer, thread_ts=thread_ts)
+
+
+@app.event("app_home_opened")
+def update_home_tab(client, event, logger):
+    try:
+        # views.publish is the method that your app uses to push a view to the Home tab
+        client.views_publish(
+            # the user that opened your app's app home
+            user_id=event["user"],
+            # the view object that appears in the app home
+            view={
+                "type": "home",
+                "callback_id": "home_view",
+                # body of the view
+                "blocks": [
+                    {"type": "section", "text": {"type": "mrkdwn", "text": "*Hello, I'm _BusterBot_* :tada:"}},
+                    {"type": "divider"},
+                    {
+                        "type": "section",
+                        "text": {
+                            "type": "mrkdwn",
+                            "text": (
+                                "I am a chatbot 🤖 designed to answer questions related to technical documentation.\n\n"
+                                "I use OpenAI's GPT models to find the most relevant sections of the documentation and respond based on them.\n"
+                                "I am open-source, and my code is available on GitHub: https://github.com/jerpint/buster\n\n"
+                                "For more information, contact either Jeremy or Hadrien from the AMLRT team.\n"
+                            ),
+                        },
+                    },
+                    # {
+                    #     "type": "actions",
+                    #     "elements": [{"type": "button", "text": {"type": "plain_text", "text": "Click me!"}}],
+                    # },
+                ],
+            },
+        )
+
+    except Exception as e:
+        logger.error(f"Error publishing home tab: {e}")
+
+
+# Start your app
+if __name__ == "__main__":
+    app.start(port=int(os.environ.get("PORT", 3000)))
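
Note that respond_to_question only assigns answer for the two hard-coded channel IDs; a mention from any other channel reaches the say() call with answer undefined and raises a NameError. A small guard along these lines (a sketch, not part of this commit) would make the handler ignore unconfigured channels:

    elif channel == ORION_CHANNEL:
        print("*******using ORION********")
        answer = orion_chatbot.process_input(text)
    else:
        # Channel not configured for any chatbot: ignore the mention.
        return
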
buster/chatbot.py CHANGED
@@ -1,8 +1,10 @@
 import logging
+from dataclasses import dataclass, field
 
 import numpy as np
 import openai
 import pandas as pd
+from omegaconf import OmegaConf
 from openai.embeddings_utils import cosine_similarity, get_embedding
 
 from buster.docparser import EMBEDDING_MODEL
@@ -11,107 +13,209 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
 
 
-# search through the reviews for a specific product
-def rank_documents(df: pd.DataFrame, query: str, top_k: int = 1, thresh: float = None) -> pd.DataFrame:
-    product_embedding = get_embedding(
-        query,
-        engine=EMBEDDING_MODEL,
-    )
-    df["similarity"] = df.embedding.apply(lambda x: cosine_similarity(x, product_embedding))
-
-    if thresh:
-        df = df[df.similarity > thresh]
-
-    if top_k == -1:
-        # return all results
-        n = len(df)
-
-    results = df.sort_values("similarity", ascending=False).head(top_k)
-    return results
-
-
-def engineer_prompt(question: str, documents: list[str]) -> str:
-    documents_str = " ".join(documents)
-    if len(documents_str) > 3000:
-        logger.info("truncating documents to fit...")
-        documents_str = documents_str[0:3000]
-    return documents_str + "\nNow answer the following question:\n" + question
-
-
-def format_response(response_text, sources_url=None):
-
-    response = f"{response_text}\n"
-
-    if sources_url:
-        response += f"<br><br>Here are the sources I used to answer your question:\n"
-        for url in sources_url:
-            response += f"<br>[{url}]({url})\n"
-
-    response += "<br><br>"
-    response += """
-    ```
-    I'm a bot 🤖 and not always perfect.
-    For more info, view the full documentation here (https://docs.mila.quebec/) or contact [email protected]
-    ```
-    """
-    return response
-
-
-def answer_question(question: str, df, top_k: int = 1, thresh: float = None) -> str:
-    # rank the documents, get the highest scoring doc and generate the prompt
-    candidates = rank_documents(df, query=question, top_k=top_k, thresh=thresh)
-
-    logger.info(f"candidate responses: {candidates}")
-
-    if len(candidates) == 0:
-        return format_response("I did not find any relevant documentation related to your question.")
-
-    documents = candidates.text.to_list()
-    sources_url = candidates.url.to_list()
-    prompt = engineer_prompt(question, documents)
-
-    logger.info(f"querying GPT...")
-    logger.info(f"User Question:\n{question}")
-    # Call the API to generate a response
-    try:
-        response = openai.Completion.create(
-            engine="text-davinci-003",
-            prompt=prompt,
-            max_tokens=200,
-            # temperature=0,
-            # top_p=0,
-            frequency_penalty=1,
-            presence_penalty=1,
-        )
-
-        # Get the response text
-        response_text = response["choices"][0]["text"]
-        logger.info(
-            f"""
-            GPT Response:\n{response_text}
-            """
-        )
-        return format_response(response_text, sources_url)
-
-    except Exception as e:
-        import traceback
-
-        logging.error(traceback.format_exc())
-        response = "Oops, something went wrong. Try again later!"
-        return format_response(response)
-
-
-def load_embeddings(path: str) -> pd.DataFrame:
-    logger.info(f"loading embeddings from {path}...")
-    df = pd.read_csv(path)
-    df["embedding"] = df.embedding.apply(eval).apply(np.array)
-    logger.info(f"embeddings loaded.")
-    return df
-
-
-if __name__ == "__main__":
-    # we generate the embeddings using docparser.py
-    df = load_embeddings("data/document_embeddings.csv")
-
-    question = "Where should I put my datasets when I am running a job?"
-    response = answer_question(question, df)
+def load_documents(path: str) -> pd.DataFrame:
+    logger.info(f"loading embeddings from {path}...")
+    df = pd.read_csv(path)
+    df["embedding"] = df.embedding.apply(eval).apply(np.array)
+    logger.info(f"embeddings loaded.")
+    return df
+
+
+class Chatbot:
+    def __init__(self, cfg: OmegaConf):
+        # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
+        self.cfg = cfg
+        self._init_documents()
+        self._init_unk_embedding()
+
+    def _init_documents(self):
+        self.documents = load_documents(self.cfg.documents_csv)
+
+    def _init_unk_embedding(self):
+        logger.info("Generating UNK token...")
+        unknown_prompt = self.cfg.unknown_prompt
+        engine = self.cfg.embedding_model
+        self.unk_embedding = get_embedding(
+            unknown_prompt,
+            engine=engine,
+        )
+
+    def rank_documents(
+        self,
+        documents: pd.DataFrame,
+        query: str,
+    ) -> pd.DataFrame:
+        """
+        Compare the question to the series of documents and return the best matching documents.
+        """
+        top_k = self.cfg.top_k
+        thresh = self.cfg.thresh
+        engine = self.cfg.embedding_model  # EMBEDDING_MODEL
+
+        query_embedding = get_embedding(
+            query,
+            engine=engine,
+        )
+        documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
+
+        # sort the matched_documents by score
+        matched_documents = documents.sort_values("similarity", ascending=False)
+
+        # limit search to top_k matched_documents.
+        top_k = len(matched_documents) if top_k == -1 else top_k
+        matched_documents = matched_documents.head(top_k)
+
+        # log matched_documents to the console
+        logger.info(f"matched documents before thresh: {matched_documents}")
+
+        # filter out matched_documents using a threshold
+        if thresh:
+            matched_documents = matched_documents[matched_documents.similarity > thresh]
+            logger.info(f"matched documents after thresh: {matched_documents}")
+
+        return matched_documents
+
+    def prepare_prompt(self, question: str, candidates: pd.DataFrame) -> str:
+        """
+        Prepare the prompt with prompt engineering.
+        """
+        max_chars = self.cfg.max_chars
+        text_before_prompt = self.cfg.text_before_prompt
+
+        documents_list = candidates.text.to_list()
+        documents_str = " ".join(documents_list)
+        if len(documents_str) > max_chars:
+            logger.info("truncating documents to fit...")
+            documents_str = documents_str[0:max_chars]
+
+        return documents_str + text_before_prompt + question
+
+    def generate_response(self, prompt: str, matched_documents: pd.DataFrame) -> str:
+        """
+        Generate a response based on the retrieved documents.
+        """
+        if len(matched_documents) == 0:
+            # No matching documents were retrieved, return
+            response_text = "I did not find any relevant documentation related to your question."
+            return response_text
+
+        logger.info(f"querying GPT...")
+        # Call the API to generate a response
+        try:
+            completion_kwargs = self.cfg.completion_kwargs
+            completion_kwargs["prompt"] = prompt
+            response = openai.Completion.create(**completion_kwargs)
+
+            # Get the response text
+            response_text = response["choices"][0]["text"]
+            logger.info(f"GPT Response:\n{response_text}")
+            return response_text
+
+        except Exception as e:
+            # log the error and return a generic response instead.
+            import traceback
+
+            logging.error(traceback.format_exc())
+            response_text = "Oops, something went wrong. Try again later!"
+            return response_text
+
+    def add_sources(self, response: str, matched_documents: pd.DataFrame):
+        """
+        Add sources from the matched documents to the response.
+        """
+        sep = self.cfg.separator  # \n
+        format = self.cfg.link_format
+
+        urls = matched_documents.url.to_list()
+        names = matched_documents.name.to_list()
+        similarities = matched_documents.similarity.to_list()
+
+        response += f"{sep}{sep}Here are the sources I used to answer your question:\n"
+        for url, name, similarity in zip(urls, names, similarities):
+            if format == "markdown":
+                response += f"{sep}[{name}]({url}){sep}"
+            elif format == "slack":
+                response += f"• <{url}|{name}>, score: {similarity:2.3f}{sep}"
+            else:
+                raise ValueError(f"{format} is not a valid URL format.")
+
+        return response
+
+    def format_response(self, response: str, matched_documents: pd.DataFrame) -> str:
+        """
+        Format the response by adding the sources if necessary, and a disclaimer prompt.
+        """
+        sep = self.cfg.separator
+        text_after_response = self.cfg.text_after_response
+
+        if len(matched_documents) > 0:
+            # we have matched documents, now we check to see if the answer is meaningful
+            response_embedding = get_embedding(
+                response,
+                engine=EMBEDDING_MODEL,
+            )
+            score = cosine_similarity(response_embedding, self.unk_embedding)
+            logger.info(f"UNK score: {score}")
+            if score < 0.9:
+                # Likely that the answer is meaningful, add the top sources
+                response = self.add_sources(response, matched_documents=matched_documents)
+
+        response += f"{sep}{sep}{sep}{text_after_response}{sep}"
+
+        return response
+
+    def process_input(self, question: str) -> str:
+        """
+        Main function to process the input question and generate a formatted output.
+        """
+        logger.info(f"User Question:\n{question}")
+
+        matched_documents = self.rank_documents(documents=self.documents, query=question)
+        prompt = self.prepare_prompt(question, matched_documents)
+        response = self.generate_response(prompt, matched_documents)
+        formatted_output = self.format_response(response, matched_documents)
+
+        return formatted_output
+
+
+@dataclass
+class ChatbotConfig:
+    """Configuration object for a chatbot.
+
+    documents_csv: Path to the csv file containing the documents and their embeddings.
+    embedding_model: OpenAI model to use to get embeddings.
+    top_k: Max number of documents to retrieve, ordered by cosine similarity.
+    thresh: Threshold for cosine similarity to be considered.
+    max_chars: Maximum number of characters the retrieved documents can be. Will truncate otherwise.
+    completion_kwargs: kwargs for the openai.Completion.create() method.
+    separator: The separator to use, can be either "\n" or <p> depending on rendering.
+    link_format: The type of format to render links with, e.g. slack or markdown.
+    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
+    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
+    text_after_response: Generic response to add to the chatbot's reply.
+    """
+
+    documents_csv: str = "buster/data/document_embeddings.csv"
+    embedding_model: str = "text-embedding-ada-002"
+    top_k: int = 3
+    thresh: float = 0.7
+    max_chars: int = 3000
+
+    completion_kwargs: dict = field(
+        default_factory=lambda: {
+            "engine": "text-davinci-003",
+            "max_tokens": 200,
+            "temperature": None,
+            "top_p": None,
+            "frequency_penalty": 1,
+            "presence_penalty": 1,
+        }
+    )
+    separator: str = "\n"
+    link_format: str = "slack"
+    unknown_prompt: str = "I Don't know how to answer your question."
+    text_before_prompt: str = "I'm a chatbot, bleep bloop."
+    text_after_response: str = "Answer the following question:\n"
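
Two details of the new formatting path are worth noting: format_response only appends sources when the answer's embedding is sufficiently different from the unknown_prompt embedding (UNK score below 0.9), and add_sources renders each source according to link_format. A rough illustration of the two link branches, with a made-up document name, URL, and score:

    # link_format="slack"    ->  "• <https://docs.mila.quebec/Userguide.html|Userguide>, score: 0.871\n"
    # link_format="markdown" ->  "\n[Userguide](https://docs.mila.quebec/Userguide.html)\n"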