Spaces:
Runtime error
Runtime error
Marc-Antoine Rondeau
commited on
Commit
Β·
0e4a27a
1
Parent(s):
8b5fed9
Initial Formatters
Browse files- buster/chatbot.py +32 -61
- buster/formatter/__init__.py +6 -0
- buster/formatter/base.py +63 -0
- buster/formatter/html.py +40 -0
- buster/formatter/markdown.py +27 -0
- buster/formatter/slack.py +27 -0
buster/chatbot.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import logging
|
2 |
import os
|
3 |
from dataclasses import dataclass, field
|
|
|
4 |
|
5 |
import numpy as np
|
6 |
import openai
|
@@ -9,6 +10,9 @@ import promptlayer
|
|
9 |
from openai.embeddings_utils import cosine_similarity, get_embedding
|
10 |
|
11 |
from buster.docparser import read_documents
|
|
|
|
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
logging.basicConfig(level=logging.INFO)
|
@@ -149,53 +153,48 @@ class Chatbot:
|
|
149 |
documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
|
150 |
return text_before_documents + documents_str + text_before_prompt + question
|
151 |
|
152 |
-
def get_gpt_response(self, **completion_kwargs):
|
153 |
# Call the API to generate a response
|
154 |
logger.info(f"querying GPT...")
|
155 |
try:
|
156 |
-
|
157 |
-
|
158 |
except Exception as e:
|
159 |
# log the error and return a generic response instead.
|
160 |
logger.exception("Error connecting to OpenAI API. See traceback:")
|
161 |
-
|
162 |
-
|
|
|
|
|
163 |
|
164 |
-
def generate_response(
|
|
|
|
|
165 |
"""
|
166 |
Generate a response based on the retrieved documents.
|
167 |
"""
|
168 |
if len(matched_documents) == 0:
|
169 |
# No matching documents were retrieved, return
|
170 |
-
return unknown_prompt
|
171 |
|
172 |
logger.info(f"Prompt: {prompt}")
|
173 |
response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
response += f"{sep}{sep}π Here are the sources I used to answer your question:{sep}{sep}"
|
188 |
-
for url, name, similarity in zip(urls, names, similarities):
|
189 |
-
if format == "markdown":
|
190 |
-
response += f"[π {name}]({url}), relevance: {similarity:2.3f}{sep}"
|
191 |
-
elif format == "html":
|
192 |
-
response += f"<a href='{url}'>π {name}</a>{sep}"
|
193 |
-
elif format == "slack":
|
194 |
-
response += f"<{url}|π {name}>, relevance: {similarity:2.3f}{sep}"
|
195 |
else:
|
196 |
-
|
197 |
|
198 |
-
return response
|
199 |
|
200 |
def check_response_relevance(
|
201 |
self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
|
@@ -217,32 +216,7 @@ class Chatbot:
|
|
217 |
# Likely that the answer is meaningful, add the top sources
|
218 |
return score < unk_threshold
|
219 |
|
220 |
-
def
|
221 |
-
"""
|
222 |
-
Format the response by adding the sources if necessary, and a disclaimer prompt.
|
223 |
-
"""
|
224 |
-
sep = self.cfg.separator
|
225 |
-
|
226 |
-
is_relevant = self.check_response_relevance(
|
227 |
-
response=response,
|
228 |
-
engine=self.cfg.embedding_model,
|
229 |
-
unk_embedding=self.unk_embedding,
|
230 |
-
unk_threshold=self.cfg.unknown_threshold,
|
231 |
-
)
|
232 |
-
if is_relevant:
|
233 |
-
# Passes our relevance detection mechanism that the answer is meaningful, add the top sources
|
234 |
-
response = self.add_sources(
|
235 |
-
response=response,
|
236 |
-
matched_documents=matched_documents,
|
237 |
-
sep=self.cfg.separator,
|
238 |
-
format=self.cfg.link_format,
|
239 |
-
)
|
240 |
-
|
241 |
-
response += f"{sep}{sep}{sep}{text_after_response}{sep}"
|
242 |
-
|
243 |
-
return response
|
244 |
-
|
245 |
-
def process_input(self, question: str) -> str:
|
246 |
"""
|
247 |
Main function to process the input question and generate a formatted output.
|
248 |
"""
|
@@ -262,9 +236,6 @@ class Chatbot:
|
|
262 |
text_before_prompt=self.cfg.text_before_prompt,
|
263 |
text_before_documents=self.cfg.text_before_documents,
|
264 |
)
|
265 |
-
response = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
|
266 |
-
formatted_output = self.format_response(
|
267 |
-
response, matched_documents, text_after_response=self.cfg.text_after_response
|
268 |
-
)
|
269 |
|
270 |
-
return
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
from dataclasses import dataclass, field
|
4 |
+
from typing import Iterable
|
5 |
|
6 |
import numpy as np
|
7 |
import openai
|
|
|
10 |
from openai.embeddings_utils import cosine_similarity, get_embedding
|
11 |
|
12 |
from buster.docparser import read_documents
|
13 |
+
from buster.formatter import Formatter
|
14 |
+
from buster.formatter.base import Response, Source
|
15 |
+
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
logging.basicConfig(level=logging.INFO)
|
|
|
153 |
documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
|
154 |
return text_before_documents + documents_str + text_before_prompt + question
|
155 |
|
156 |
+
def get_gpt_response(self, **completion_kwargs) -> Response:
|
157 |
# Call the API to generate a response
|
158 |
logger.info(f"querying GPT...")
|
159 |
try:
|
160 |
+
response = openai.Completion.create(**completion_kwargs)
|
|
|
161 |
except Exception as e:
|
162 |
# log the error and return a generic response instead.
|
163 |
logger.exception("Error connecting to OpenAI API. See traceback:")
|
164 |
+
return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")
|
165 |
+
|
166 |
+
text = response["choices"][0]["text"]
|
167 |
+
return Response(text)
|
168 |
|
169 |
+
def generate_response(
|
170 |
+
self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
|
171 |
+
) -> tuple[Response, Iterable[Source]]:
|
172 |
"""
|
173 |
Generate a response based on the retrieved documents.
|
174 |
"""
|
175 |
if len(matched_documents) == 0:
|
176 |
# No matching documents were retrieved, return
|
177 |
+
return Response(unknown_prompt), tuple()
|
178 |
|
179 |
logger.info(f"Prompt: {prompt}")
|
180 |
response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
|
181 |
+
if response:
|
182 |
+
logger.info(f"GPT Response:\n{response.text}")
|
183 |
+
relevant = self.check_response_relevance(
|
184 |
+
response=response.text,
|
185 |
+
engine=self.cfg.embedding_model,
|
186 |
+
unk_embedding=self.unk_embedding,
|
187 |
+
unk_threshold=self.cfg.unknown_threshold,
|
188 |
+
)
|
189 |
+
if relevant:
|
190 |
+
sources = (
|
191 |
+
Source(dct["name"], dct["url"], dct["similarity"])
|
192 |
+
for dct in matched_documents.to_dict(orient="records")
|
193 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
else:
|
195 |
+
sources = tuple()
|
196 |
|
197 |
+
return response, sources
|
198 |
|
199 |
def check_response_relevance(
|
200 |
self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
|
|
|
216 |
# Likely that the answer is meaningful, add the top sources
|
217 |
return score < unk_threshold
|
218 |
|
219 |
+
def process_input(self, question: str, formatter: Formatter) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
"""
|
221 |
Main function to process the input question and generate a formatted output.
|
222 |
"""
|
|
|
236 |
text_before_prompt=self.cfg.text_before_prompt,
|
237 |
text_before_documents=self.cfg.text_before_documents,
|
238 |
)
|
239 |
+
response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)
|
|
|
|
|
|
|
240 |
|
241 |
+
return formatter(response, sources)
|
buster/formatter/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base import Formatter
|
2 |
+
from .html import HTMLFormatter
|
3 |
+
from .markdown import MarkdownFormatter
|
4 |
+
from .slack import SlackFormatter
|
5 |
+
|
6 |
+
__all__ = [Formatter, HTMLFormatter, MarkdownFormatter, SlackFormatter]
|
buster/formatter/base.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Iterable, NamedTuple
|
3 |
+
|
4 |
+
# Should be from the `documents` module.
|
5 |
+
class Source(NamedTuple):
|
6 |
+
name: str
|
7 |
+
url: str
|
8 |
+
question_similarity: float
|
9 |
+
# TODO Add answer similarity.
|
10 |
+
# answer_similarity: float
|
11 |
+
|
12 |
+
|
13 |
+
# Should be from the `nlp` module.
|
14 |
+
@dataclass(slots=True)
|
15 |
+
class Response:
|
16 |
+
text: str
|
17 |
+
error: bool = False
|
18 |
+
error_msg: str | None = None
|
19 |
+
|
20 |
+
def __bool__(self) -> bool:
|
21 |
+
return not self.error
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class Formatter:
|
26 |
+
|
27 |
+
source_template: str = "{source.name} (relevance: {source.question_similarity:2.3f})"
|
28 |
+
error_msg_template: str = "Something went wrong: {response.error_msg}"
|
29 |
+
error_fallback_template: str = "Something went very wrong."
|
30 |
+
sourced_answer_template: str = "{response.text}\n\nSources:\n{sources}\n\nBut what do I know, I'm a chatbot."
|
31 |
+
unsourced_answer_template: str = "{response.text}\n\nBut what do I know, I'm a chatbot."
|
32 |
+
|
33 |
+
def source_item(self, source: Source) -> str:
|
34 |
+
"""Format a single source item."""
|
35 |
+
return self.source_template.format(source=source)
|
36 |
+
|
37 |
+
def sources_list(self, sources: Iterable[Source]) -> str | None:
|
38 |
+
"""Format sources into a list."""
|
39 |
+
items = [self.source_item(source) for source in sources]
|
40 |
+
if not items:
|
41 |
+
return None # No list needed.
|
42 |
+
|
43 |
+
return "\n".join(f"{ind}. {item}" for ind, item in enumerate(items, 1))
|
44 |
+
|
45 |
+
def error(self, response: Response) -> str:
|
46 |
+
"""Format an error message."""
|
47 |
+
if response.error_msg:
|
48 |
+
return self.error_msg_template.format(response=response)
|
49 |
+
return self.error_fallback_template.format(response=response)
|
50 |
+
|
51 |
+
def answer(self, response: Response, sources: Iterable[Source]) -> str:
|
52 |
+
"""Format an answer and its sources."""
|
53 |
+
sources_list = self.sources_list(sources)
|
54 |
+
if not sources_list:
|
55 |
+
return self.sourced_answer_template.format(response=response, sources=sources_list)
|
56 |
+
|
57 |
+
return self.unsourced_answer_template.format(response=response)
|
58 |
+
|
59 |
+
def __call__(self, response: Response, sources: Iterable[Source]) -> str:
|
60 |
+
"""Format an answer and its sources, or an error message."""
|
61 |
+
if response:
|
62 |
+
return self.answer(response, sources)
|
63 |
+
return self.error(response)
|
buster/formatter/html.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
import html
|
3 |
+
from typing import Iterable
|
4 |
+
from buster.formatter.base import Formatter, Response, Source
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class HTMLFormatter(Formatter):
|
9 |
+
"""Format the answer in HTML."""
|
10 |
+
|
11 |
+
source_template: str = """<li><a href='{source.url}'>π {source.name}</a></li>"""
|
12 |
+
error_msg_template: str = """<div class="error">Something went wrong:\n<p>{response.error_msg}</p></div>"""
|
13 |
+
error_fallback_template: str = """<div class="error">Something went very wrong.</div>"""
|
14 |
+
sourced_answer_template: str = (
|
15 |
+
"""<div class="answer"><p>{response.text}</p></div>\n"""
|
16 |
+
"""<div class="sources>π Here are the sources I used to answer your question:\n"""
|
17 |
+
"""<ol>\n{sources}</ol></div>\n"""
|
18 |
+
"""<div class="footer">I'm a chatbot, bleep bloop.</div>"""
|
19 |
+
)
|
20 |
+
unsourced_answer_template: str = (
|
21 |
+
"""<div class="answer">{response.text}</div>\n<div class="footer">I'm a chatbot, bleep bloop.</div>"""
|
22 |
+
)
|
23 |
+
|
24 |
+
def sources_list(self, sources: Iterable[Source]) -> str | None:
|
25 |
+
"""Format sources into a list."""
|
26 |
+
items = [self.source_item(source) for source in sources]
|
27 |
+
if not items:
|
28 |
+
return None # No list needed.
|
29 |
+
|
30 |
+
return "\n".join(items)
|
31 |
+
|
32 |
+
def __call__(self, response: Response, sources: Iterable[Source]) -> str:
|
33 |
+
# Escape any html in the text.
|
34 |
+
response = Response(
|
35 |
+
html.escape(response.text) if response.text else response.text,
|
36 |
+
response.error,
|
37 |
+
html.escape(response.error_msg) if response.error_msg else response.error_msg,
|
38 |
+
)
|
39 |
+
sources = (Source(html.escape(source.name), source.url, source.question_similarity) for source in sources)
|
40 |
+
return super().__call__(response, sources)
|
buster/formatter/markdown.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Iterable
|
3 |
+
from buster.formatter.base import Formatter, Source
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class MarkdownFormatter(Formatter):
|
8 |
+
"""Format the answer in markdown."""
|
9 |
+
|
10 |
+
source_template: str = """[π {source.name}]({source.url}), relevance: {source.question_similarity:2.3f}"""
|
11 |
+
error_msg_template: str = """Something went wrong:\n{response.error_msg}"""
|
12 |
+
error_fallback_template: str = """Something went very wrong."""
|
13 |
+
sourced_answer_template: str = (
|
14 |
+
"""{response.text}\n\n"""
|
15 |
+
"""π Here are the sources I used to answer your question:\n"""
|
16 |
+
"""{sources}\n\n"""
|
17 |
+
"""I'm a chatbot, bleep bloop."""
|
18 |
+
)
|
19 |
+
unsourced_answer_template: str = """{response.text}\n\nI'm a chatbot, bleep bloop."""
|
20 |
+
|
21 |
+
def sources_list(self, sources: Iterable[Source]) -> str | None:
|
22 |
+
"""Format sources into a list."""
|
23 |
+
items = [self.source_item(source) for source in sources]
|
24 |
+
if not items:
|
25 |
+
return None # No list needed.
|
26 |
+
|
27 |
+
return "\n".join(f"{ind}. {item}" for ind, item in enumerate(items, 1))
|
buster/formatter/slack.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Iterable
|
3 |
+
from buster.formatter.base import Formatter, Source
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class SlackFormatter(Formatter):
|
8 |
+
"""Format the answer for Slack."""
|
9 |
+
|
10 |
+
source_template: str = """<{source.url}|π {source.name}>, relevance: {source.question_similarity:2.3f}"""
|
11 |
+
error_msg_template: str = """Something went wrong:\n{response.error_msg}"""
|
12 |
+
error_fallback_template: str = """Something went very wrong."""
|
13 |
+
sourced_answer_template: str = (
|
14 |
+
"""{response.text}\n\n"""
|
15 |
+
"""π Here are the sources I used to answer your question:\n"""
|
16 |
+
"""{sources}\n\n"""
|
17 |
+
"""I'm a chatbot, bleep bloop."""
|
18 |
+
)
|
19 |
+
unsourced_answer_template: str = """{response.text}\n\nI'm a chatbot, bleep bloop."""
|
20 |
+
|
21 |
+
def sources_list(self, sources: Iterable[Source]) -> str | None:
|
22 |
+
"""Format sources into a list."""
|
23 |
+
items = [self.source_item(source) for source in sources]
|
24 |
+
if not items:
|
25 |
+
return None # No list needed.
|
26 |
+
|
27 |
+
return "\n".join(items)
|