Marc-Antoine Rondeau committed on
Commit
0e4a27a
·
1 Parent(s): 8b5fed9

Initial Formatters

Browse files
buster/chatbot.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  from dataclasses import dataclass, field
 
4
 
5
  import numpy as np
6
  import openai
@@ -9,6 +10,9 @@ import promptlayer
9
  from openai.embeddings_utils import cosine_similarity, get_embedding
10
 
11
  from buster.docparser import read_documents
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
  logging.basicConfig(level=logging.INFO)
@@ -149,53 +153,48 @@ class Chatbot:
149
  documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
150
  return text_before_documents + documents_str + text_before_prompt + question
151
 
152
def get_gpt_response(self, **completion_kwargs):
    """Query the OpenAI completion endpoint, degrading to a canned reply on failure.

    Args:
        **completion_kwargs: Forwarded unchanged to ``openai.Completion.create``.

    Returns:
        Whatever ``openai.Completion.create`` returns on success. If the call
        raises, a dict shaped like ``{"choices": [{"text": ...}]}`` carrying a
        generic apology, so callers can index it the same way.
    """
    logger.info("querying GPT...")
    try:
        completion = openai.Completion.create(**completion_kwargs)
    except Exception:
        # Deliberate catch-all: any failure is logged with its traceback and
        # replaced by a generic, user-facing message.
        logger.exception("Error connecting to OpenAI API. See traceback:")
        fallback_text = "We're having trouble connecting to OpenAI right now... Try again soon!"
        return {"choices": [{"text": fallback_text}]}
    return completion
 
 
163
 
164
def generate_response(self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str) -> str:
    """Produce the answer text for a prompt built from the retrieved documents.

    Args:
        prompt: Full prompt sent to the completion API.
        matched_documents: Documents retrieved for the question; only its
            length is inspected here.
        unknown_prompt: Text returned as-is when nothing was retrieved.

    Returns:
        ``unknown_prompt`` when ``matched_documents`` is empty, otherwise the
        text of the first choice returned by ``get_gpt_response``.
    """
    nothing_retrieved = len(matched_documents) == 0
    if nothing_retrieved:
        # Without supporting documents the API is not called at all.
        return unknown_prompt

    logger.info(f"Prompt: {prompt}")
    completion = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
    answer = completion["choices"][0]["text"]
    logger.info(f"GPT Response:\n{answer}")
    return answer
177
-
178
def add_sources(self, response: str, matched_documents: pd.DataFrame, sep: str, format: str):
    """Append a list of the matched documents' sources to the response.

    Args:
        response: Answer text the sources are appended to.
        matched_documents: Frame read through its ``url``, ``name`` and
            ``similarity`` attributes.
        sep: Separator placed around the header and after every source entry.
        format: One of ``"markdown"``, ``"html"`` or ``"slack"``.

    Returns:
        ``response`` followed by a header and one formatted entry per document.

    Raises:
        ValueError: If ``format`` is not a supported value. The check runs per
            document, so it is only reached when there is at least one.
    """
    links = matched_documents.url.to_list()
    titles = matched_documents.name.to_list()
    scores = matched_documents.similarity.to_list()

    pieces = [response, f"{sep}{sep}πŸ“ Here are the sources I used to answer your question:{sep}{sep}"]
    for link, title, score in zip(links, titles, scores):
        if format == "markdown":
            entry = f"[πŸ”— {title}]({link}), relevance: {score:2.3f}{sep}"
        elif format == "html":
            entry = f"<a href='{link}'>πŸ”— {title}</a>{sep}"
        elif format == "slack":
            entry = f"<{link}|πŸ”— {title}>, relevance: {score:2.3f}{sep}"
        else:
            raise ValueError(f"{format} is not a valid URL format.")
        pieces.append(entry)

    return "".join(pieces)
199
 
200
  def check_response_relevance(
201
  self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
@@ -217,32 +216,7 @@ class Chatbot:
217
  # Likely that the answer is meaningful, add the top sources
218
  return score < unk_threshold
219
 
220
def format_response(self, response: str, matched_documents: pd.DataFrame, text_after_response: str) -> str:
    """Decorate an answer with its sources (when relevant) and a closing text.

    Args:
        response: Raw answer text.
        matched_documents: Documents handed to ``add_sources`` when the answer
            passes the relevance check.
        text_after_response: Closing text appended after the answer.

    Returns:
        The answer, optionally followed by its sources, then three separators,
        ``text_after_response`` and one final separator.
    """
    cfg = self.cfg
    separator = cfg.separator

    passes_relevance_check = self.check_response_relevance(
        response=response,
        engine=cfg.embedding_model,
        unk_embedding=self.unk_embedding,
        unk_threshold=cfg.unknown_threshold,
    )
    if passes_relevance_check:
        # Sources are only listed for answers that pass the relevance check.
        response = self.add_sources(
            response=response,
            matched_documents=matched_documents,
            sep=cfg.separator,
            format=cfg.link_format,
        )

    closing = f"{separator}{separator}{separator}{text_after_response}{separator}"
    return response + closing
244
-
245
- def process_input(self, question: str) -> str:
246
  """
247
  Main function to process the input question and generate a formatted output.
248
  """
@@ -262,9 +236,6 @@ class Chatbot:
262
  text_before_prompt=self.cfg.text_before_prompt,
263
  text_before_documents=self.cfg.text_before_documents,
264
  )
265
- response = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
266
- formatted_output = self.format_response(
267
- response, matched_documents, text_after_response=self.cfg.text_after_response
268
- )
269
 
270
- return formatted_output
 
1
  import logging
2
  import os
3
  from dataclasses import dataclass, field
4
+ from typing import Iterable
5
 
6
  import numpy as np
7
  import openai
 
10
  from openai.embeddings_utils import cosine_similarity, get_embedding
11
 
12
  from buster.docparser import read_documents
13
+ from buster.formatter import Formatter
14
+ from buster.formatter.base import Response, Source
15
+
16
 
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
 
153
  documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
154
  return text_before_documents + documents_str + text_before_prompt + question
155
 
156
def get_gpt_response(self, **completion_kwargs) -> Response:
    """Call the OpenAI completion endpoint and wrap the outcome in a ``Response``.

    Args:
        **completion_kwargs: Forwarded unchanged to ``openai.Completion.create``.

    Returns:
        A ``Response`` holding the text of the first choice on success. If the
        API call raises, an error ``Response`` with empty text and a generic
        user-facing message.
    """
    logger.info("querying GPT...")
    try:
        completion = openai.Completion.create(**completion_kwargs)
    except Exception:
        # Deliberate catch-all: the failure is logged with its traceback and
        # reported to the caller as an error Response instead of propagating.
        logger.exception("Error connecting to OpenAI API. See traceback:")
        return Response(
            text="",
            error=True,
            error_msg="We're having trouble connecting to OpenAI right now... Try again soon!",
        )

    # Only the API call is guarded; extracting the text happens outside the try.
    return Response(completion["choices"][0]["text"])
168
 
169
def generate_response(
    self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
) -> tuple[Response, Iterable[Source]]:
    """
    Generate a response based on the retrieved documents.

    Args:
        prompt: Full prompt sent to the completion API.
        matched_documents: Retrieved documents; each record is read through its
            ``name``, ``url`` and ``similarity`` keys.
        unknown_prompt: Text used as the answer when nothing was retrieved.

    Returns:
        A ``(response, sources)`` pair. ``sources`` is an empty tuple when no
        document was retrieved, when the API call failed, or when the answer
        does not pass the relevance check; otherwise it holds one ``Source``
        per matched document.
    """
    if len(matched_documents) == 0:
        # No matching documents were retrieved, return
        return Response(unknown_prompt), tuple()

    logger.info(f"Prompt: {prompt}")
    response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
    if not response:
        # Error response: there is no answer text to check, so skip the
        # relevance check. Returning here keeps every later name bound.
        return response, tuple()

    logger.info(f"GPT Response:\n{response.text}")
    relevant = self.check_response_relevance(
        response=response.text,
        engine=self.cfg.embedding_model,
        unk_embedding=self.unk_embedding,
        unk_threshold=self.cfg.unknown_threshold,
    )
    if not relevant:
        return response, tuple()

    # Materialised as a tuple so every branch returns the same re-iterable type.
    sources = tuple(
        Source(dct["name"], dct["url"], dct["similarity"])
        for dct in matched_documents.to_dict(orient="records")
    )
    return response, sources
198
 
199
  def check_response_relevance(
200
  self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
 
216
  # Likely that the answer is meaningful, add the top sources
217
  return score < unk_threshold
218
 
219
+ def process_input(self, question: str, formatter: Formatter) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  """
221
  Main function to process the input question and generate a formatted output.
222
  """
 
236
  text_before_prompt=self.cfg.text_before_prompt,
237
  text_before_documents=self.cfg.text_before_documents,
238
  )
239
+ response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
 
 
 
240
 
241
+ return formatter(response, sources)
buster/formatter/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .base import Formatter
2
+ from .html import HTMLFormatter
3
+ from .markdown import MarkdownFormatter
4
+ from .slack import SlackFormatter
5
+
6
# `__all__` must list the exported names as strings; listing the class objects
# makes `from buster.formatter import *` raise TypeError.
__all__ = ["Formatter", "HTMLFormatter", "MarkdownFormatter", "SlackFormatter"]
buster/formatter/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable, NamedTuple
3
+
4
+ # Should be from the `documents` module.
5
class Source(NamedTuple):
    """A document cited alongside an answer.

    Field order matters: instances can be built positionally as
    ``Source(name, url, question_similarity)``.
    """

    # Display name of the document.
    name: str
    # Link to the document.
    url: str
    # Similarity score between the document and the question.
    question_similarity: float
    # TODO Add answer similarity.
    # answer_similarity: float
11
+
12
+
13
+ # Should be from the `nlp` module.
14
+ @dataclass(slots=True)
15
+ class Response:
16
+ text: str
17
+ error: bool = False
18
+ error_msg: str | None = None
19
+
20
+ def __bool__(self) -> bool:
21
+ return not self.error
22
+
23
+
24
@dataclass
class Formatter:
    """Plain-text formatter that renders a response and its sources.

    Every piece of output comes from a ``str.format`` template stored as a
    dataclass field, so subclasses can change the markup by overriding the
    templates and, where needed, ``sources_list``.
    """

    # Rendered once per source; receives ``source``.
    source_template: str = "{source.name} (relevance: {source.question_similarity:2.3f})"
    # Used for error responses that carry a message; receives ``response``.
    error_msg_template: str = "Something went wrong: {response.error_msg}"
    # Used for error responses without a message.
    error_fallback_template: str = "Something went very wrong."
    # Used when there is at least one source; receives ``response`` and ``sources``.
    sourced_answer_template: str = "{response.text}\n\nSources:\n{sources}\n\nBut what do I know, I'm a chatbot."
    # Used when there are no sources; receives ``response``.
    unsourced_answer_template: str = "{response.text}\n\nBut what do I know, I'm a chatbot."

    def source_item(self, source: Source) -> str:
        """Format a single source item."""
        return self.source_template.format(source=source)

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Format sources into a numbered list, or return None when there are none."""
        items = [self.source_item(source) for source in sources]
        if not items:
            return None  # No list needed.

        return "\n".join(f"{ind}. {item}" for ind, item in enumerate(items, 1))

    def error(self, response: Response) -> str:
        """Format an error message, falling back to a generic one when none is set."""
        if response.error_msg:
            return self.error_msg_template.format(response=response)
        return self.error_fallback_template.format(response=response)

    def answer(self, response: Response, sources: Iterable[Source]) -> str:
        """Format an answer, listing its sources when there are any."""
        sources_list = self.sources_list(sources)
        if sources_list:
            # The sourced template is only valid with a rendered list; with
            # None it would print the literal text "None".
            return self.sourced_answer_template.format(response=response, sources=sources_list)

        return self.unsourced_answer_template.format(response=response)

    def __call__(self, response: Response, sources: Iterable[Source]) -> str:
        """Format an answer and its sources, or an error message."""
        if response:
            return self.answer(response, sources)
        return self.error(response)
buster/formatter/html.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import html
3
+ from typing import Iterable
4
+ from buster.formatter.base import Formatter, Response, Source
5
+
6
+
7
@dataclass
class HTMLFormatter(Formatter):
    """Format the answer in HTML.

    The response text, error message, source names and source URLs are
    HTML-escaped in ``__call__`` before being substituted into the templates.
    """

    # One <li> per source; the URL goes into a single-quoted href attribute.
    source_template: str = """<li><a href='{source.url}'>πŸ”— {source.name}</a></li>"""
    error_msg_template: str = """<div class="error">Something went wrong:\n<p>{response.error_msg}</p></div>"""
    error_fallback_template: str = """<div class="error">Something went very wrong.</div>"""
    sourced_answer_template: str = (
        """<div class="answer"><p>{response.text}</p></div>\n"""
        """<div class="sources">πŸ“ Here are the sources I used to answer your question:\n"""
        """<ol>\n{sources}</ol></div>\n"""
        """<div class="footer">I'm a chatbot, bleep bloop.</div>"""
    )
    unsourced_answer_template: str = (
        """<div class="answer">{response.text}</div>\n<div class="footer">I'm a chatbot, bleep bloop.</div>"""
    )

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Join the <li> items, or return None when there are no sources.

        No numbering is added: the items are placed inside an <ol> by
        ``sourced_answer_template``.
        """
        items = [self.source_item(source) for source in sources]
        if not items:
            return None  # No list needed.

        return "\n".join(items)

    def __call__(self, response: Response, sources: Iterable[Source]) -> str:
        """Escape all interpolated text, then format as the base class does."""
        # Escape any html in the text.
        response = Response(
            html.escape(response.text) if response.text else response.text,
            response.error,
            html.escape(response.error_msg) if response.error_msg else response.error_msg,
        )
        # html.escape also escapes quotes by default, so a URL cannot break out
        # of the single-quoted href attribute.
        # NOTE(review): the URL scheme is not validated (e.g. `javascript:`);
        # confirm that source URLs come from a trusted index.
        sources = (
            Source(html.escape(source.name), html.escape(source.url), source.question_similarity)
            for source in sources
        )
        return super().__call__(response, sources)
buster/formatter/markdown.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable
3
+ from buster.formatter.base import Formatter, Source
4
+
5
+
6
@dataclass
class MarkdownFormatter(Formatter):
    """Markdown flavour of the formatter: sources become numbered links."""

    # Markdown link followed by the similarity score.
    source_template: str = "[πŸ”— {source.name}]({source.url}), relevance: {source.question_similarity:2.3f}"
    error_msg_template: str = "Something went wrong:\n{response.error_msg}"
    error_fallback_template: str = "Something went very wrong."
    sourced_answer_template: str = (
        "{response.text}\n\n"
        "πŸ“ Here are the sources I used to answer your question:\n"
        "{sources}\n\n"
        "I'm a chatbot, bleep bloop."
    )
    unsourced_answer_template: str = "{response.text}\n\nI'm a chatbot, bleep bloop."

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Render the sources as a 1-based numbered list, or None when empty."""
        numbered = [
            f"{position}. {self.source_item(source)}"
            for position, source in enumerate(sources, 1)
        ]
        return "\n".join(numbered) if numbered else None
buster/formatter/slack.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Iterable
3
+ from buster.formatter.base import Formatter, Source
4
+
5
+
6
@dataclass
class SlackFormatter(Formatter):
    """Slack flavour of the formatter: sources use the ``<url|label>`` link form."""

    # Slack-style link followed by the similarity score.
    source_template: str = "<{source.url}|πŸ”— {source.name}>, relevance: {source.question_similarity:2.3f}"
    error_msg_template: str = "Something went wrong:\n{response.error_msg}"
    error_fallback_template: str = "Something went very wrong."
    sourced_answer_template: str = (
        "{response.text}\n\n"
        "πŸ“ Here are the sources I used to answer your question:\n"
        "{sources}\n\n"
        "I'm a chatbot, bleep bloop."
    )
    unsourced_answer_template: str = "{response.text}\n\nI'm a chatbot, bleep bloop."

    def sources_list(self, sources: Iterable[Source]) -> str | None:
        """Render one source per line without numbering, or None when empty."""
        lines = list(map(self.source_item, sources))
        return "\n".join(lines) if lines else None