jerpint commited on
Commit
6a4ac5a
·
1 Parent(s): 9143594

big spring cleaning

Browse files
buster/busterbot.py CHANGED
@@ -7,17 +7,19 @@ import pandas as pd
7
  from openai.embeddings_utils import cosine_similarity, get_embedding
8
 
9
  from buster.completers import get_completer
10
- from buster.formatter import (
11
- Response,
12
- ResponseFormatter,
13
- Source,
14
- response_formatter_factory,
15
- )
16
 
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
20
 
 
 
 
 
 
 
21
  @dataclass
22
  class BusterConfig:
23
  """Configuration object for a chatbot.
@@ -36,7 +38,7 @@ class BusterConfig:
36
  source: the source of the document to consider
37
  """
38
 
39
- documents_file: str = "buster/data/document_embeddings.tar.gz"
40
  embedding_model: str = "text-embedding-ada-002"
41
  top_k: int = 3
42
  thresh: float = 0.7
@@ -58,9 +60,8 @@ class BusterConfig:
58
  },
59
  }
60
  )
61
- response_format: str = "slack"
62
  unknown_prompt: str = "I Don't know how to answer your question."
63
- response_footnote: str = "I'm a bot 🤖 and not always perfect."
64
  source: str = ""
65
 
66
 
@@ -91,9 +92,13 @@ class Buster:
91
  self.cfg = cfg
92
  self.completer = get_completer(cfg.completer_cfg)
93
  self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
94
- self.response_formatter = response_formatter_factory(
95
- format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
 
 
 
96
  )
 
97
  logger.info(f"Config Updated.")
98
 
99
  @lru_cache
@@ -129,38 +134,8 @@ class Buster:
129
 
130
  return matched_documents
131
 
132
- def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
133
- # gather the documents in one large plaintext variable
134
- documents_list = matched_documents.content.to_list()
135
- documents_str = ""
136
- for idx, doc in enumerate(documents_list):
137
- documents_str += f"<DOCUMENT> {doc} <\DOCUMENT>"
138
-
139
- # truncate the documents to fit
140
- # TODO: increase to actual token count
141
- word_count = len(documents_str.split(" "))
142
- if word_count > max_words:
143
- logger.info("truncating documents to fit...")
144
- documents_str = " ".join(documents_str.split(" ")[0:max_words])
145
- logger.info(f"Documents after truncation: {documents_str}")
146
-
147
- return documents_str
148
-
149
- def add_sources(
150
- self,
151
- matched_documents: pd.DataFrame,
152
- ):
153
- sources = (
154
- Source(
155
- source=dct["source"], title=dct["title"], url=dct["url"], question_similarity=dct["similarity"] * 100
156
- )
157
- for dct in matched_documents.to_dict(orient="records")
158
- )
159
-
160
- return sources
161
-
162
  def check_response_relevance(
163
- self, completion: str, engine: str, unk_embedding: np.array, unk_threshold: float
164
  ) -> bool:
165
  """Check to see if a response is relevant to the chatbot's knowledge or not.
166
 
@@ -170,7 +145,7 @@ class Buster:
170
  set the unk_threshold to 0 to essentially turn off this feature.
171
  """
172
  response_embedding = self.get_embedding(
173
- completion,
174
  engine=engine,
175
  )
176
  score = cosine_similarity(response_embedding, unk_embedding)
@@ -179,7 +154,7 @@ class Buster:
179
  # Likely that the answer is meaningful, add the top sources
180
  return score < unk_threshold
181
 
182
- def process_input(self, user_input: str, formatter: ResponseFormatter = None) -> str:
183
  """
184
  Main function to process the input question and generate a formatted output.
185
  """
@@ -199,28 +174,29 @@ class Buster:
199
  )
200
 
201
  if len(matched_documents) == 0:
202
- response = Response(self.cfg.unknown_prompt)
203
- sources = tuple()
204
- return self.response_formatter(response, sources)
205
-
206
- # generate a completion
207
- documents: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
208
- response: Response = self.completer.generate_response(user_input, documents)
209
- logger.info(f"GPT Response:\n{response.text}")
210
 
211
- sources = self.add_sources(matched_documents)
 
 
 
212
 
213
  # check for relevance
214
  relevant = self.check_response_relevance(
215
- completion=response.text,
216
  engine=self.cfg.embedding_model,
217
  unk_embedding=self.unk_embedding,
218
  unk_threshold=self.cfg.unknown_threshold,
219
  )
220
  if not relevant:
 
221
  # answer generated was the chatbot saying it doesn't know how to answer
222
- # override completion with generic "I don't know"
223
- response = Response(text=self.cfg.unknown_prompt)
224
- sources = tuple()
225
 
226
- return self.response_formatter(response, sources)
 
 
7
  from openai.embeddings_utils import cosine_similarity, get_embedding
8
 
9
  from buster.completers import get_completer
10
+ from buster.completers.base import Completion
11
+ from buster.formatters.prompts import SystemPromptFormatter
 
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
  logging.basicConfig(level=logging.INFO)
15
 
16
 
17
+ @dataclass(slots=True)
18
+ class Response:
19
+ completion: Completion
20
+ matched_documents: pd.DataFrame | None = None
21
+
22
+
23
  @dataclass
24
  class BusterConfig:
25
  """Configuration object for a chatbot.
 
38
  source: the source of the document to consider
39
  """
40
 
41
+ documents_file: str = ""
42
  embedding_model: str = "text-embedding-ada-002"
43
  top_k: int = 3
44
  thresh: float = 0.7
 
60
  },
61
  }
62
  )
 
63
  unknown_prompt: str = "I Don't know how to answer your question."
64
+ response_format: str = "slack"
65
  source: str = ""
66
 
67
 
 
92
  self.cfg = cfg
93
  self.completer = get_completer(cfg.completer_cfg)
94
  self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
95
+
96
+ self.prompt_formatter = SystemPromptFormatter(
97
+ text_before_docs=self.cfg.completer_cfg["text_before_documents"],
98
+ text_after_docs=self.cfg.completer_cfg["text_before_prompt"],
99
+ max_words=self.cfg.max_words,
100
  )
101
+
102
  logger.info(f"Config Updated.")
103
 
104
  @lru_cache
 
134
 
135
  return matched_documents
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def check_response_relevance(
138
+ self, completion_text: str, engine: str, unk_embedding: np.array, unk_threshold: float
139
  ) -> bool:
140
  """Check to see if a response is relevant to the chatbot's knowledge or not.
141
 
 
145
  set the unk_threshold to 0 to essentially turn off this feature.
146
  """
147
  response_embedding = self.get_embedding(
148
+ completion_text,
149
  engine=engine,
150
  )
151
  score = cosine_similarity(response_embedding, unk_embedding)
 
154
  # Likely that the answer is meaningful, add the top sources
155
  return score < unk_threshold
156
 
157
+ def process_input(self, user_input: str) -> str:
158
  """
159
  Main function to process the input question and generate a formatted output.
160
  """
 
174
  )
175
 
176
  if len(matched_documents) == 0:
177
+ logger.warning("No documents found...")
178
+ completion = Completion(text="No documents found.")
179
+ matched_documents = pd.DataFrame(columns=matched_documents.columns)
180
+ response = Response(completion=completion, matched_documents=matched_documents)
181
+ return response
 
 
 
182
 
183
+ # prepare the prompt
184
+ system_prompt = self.prompt_formatter.format(matched_documents)
185
+ completion: Completion = self.completer.generate_response(user_input=user_input, system_prompt=system_prompt)
186
+ logger.info(f"GPT Response:\n{completion.text}")
187
 
188
  # check for relevance
189
  relevant = self.check_response_relevance(
190
+ completion_text=completion.text,
191
  engine=self.cfg.embedding_model,
192
  unk_embedding=self.unk_embedding,
193
  unk_threshold=self.cfg.unknown_threshold,
194
  )
195
  if not relevant:
196
+ matched_documents = pd.DataFrame(columns=matched_documents.columns)
197
  # answer generated was the chatbot saying it doesn't know how to answer
198
+ # uncomment override completion with unknown prompt
199
+ # completion = Completion(text=self.cfg.unknown_prompt)
 
200
 
201
+ response = Response(completion=completion, matched_documents=matched_documents)
202
+ return response
buster/completers/base.py CHANGED
@@ -19,12 +19,14 @@ if promptlayer_api_key:
19
  openai = promptlayer.openai
20
  openai.api_key = os.environ.get("OPENAI_API_KEY")
21
 
 
22
  @dataclass(slots=True)
23
  class Completion:
24
  text: str
25
  error: bool = False
26
  error_msg: str | None = None
27
 
 
28
  class Completer(ABC):
29
  def __init__(self, completion_kwargs: dict):
30
  self.completion_kwargs = completion_kwargs
 
19
  openai = promptlayer.openai
20
  openai.api_key = os.environ.get("OPENAI_API_KEY")
21
 
22
+
23
  @dataclass(slots=True)
24
  class Completion:
25
  text: str
26
  error: bool = False
27
  error_msg: str | None = None
28
 
29
+
30
  class Completer(ABC):
31
  def __init__(self, completion_kwargs: dict):
32
  self.completion_kwargs = completion_kwargs
buster/formatters/prompts.py CHANGED
@@ -6,6 +6,7 @@ import pandas as pd
6
  logger = logging.getLogger(__name__)
7
  logging.basicConfig(level=logging.INFO)
8
 
 
9
  @dataclass
10
  class SystemPromptFormatter:
11
  text_before_docs: str = ""
@@ -38,4 +39,4 @@ class SystemPromptFormatter:
38
  """
39
  documents = self.format_documents(matched_documents, max_words=self.max_words)
40
  system_prompt = self.text_before_docs + documents + self.text_after_docs
41
- return system_prompt
 
6
  logger = logging.getLogger(__name__)
7
  logging.basicConfig(level=logging.INFO)
8
 
9
+
10
  @dataclass
11
  class SystemPromptFormatter:
12
  text_before_docs: str = ""
 
39
  """
40
  documents = self.format_documents(matched_documents, max_words=self.max_words)
41
  system_prompt = self.text_before_docs + documents + self.text_after_docs
42
+ return system_prompt