Marc-Antoine Rondeau commited on
Commit
cb9fd37
·
1 Parent(s): 0e4a27a

Patch formatter creation for backward compatibility.

Browse files
Files changed (1) hide show
  1. buster/chatbot.py +10 -2
buster/chatbot.py CHANGED
@@ -10,10 +10,13 @@ import promptlayer
10
  from openai.embeddings_utils import cosine_similarity, get_embedding
11
 
12
  from buster.docparser import read_documents
13
- from buster.formatter import Formatter
14
  from buster.formatter.base import Response, Source
15
 
16
 
 
 
 
17
  logger = logging.getLogger(__name__)
18
  logging.basicConfig(level=logging.INFO)
19
 
@@ -216,11 +219,16 @@ class Chatbot:
216
  # Likely that the answer is meaningful, add the top sources
217
  return score < unk_threshold
218
 
219
- def process_input(self, question: str, formatter: Formatter) -> str:
220
  """
221
  Main function to process the input question and generate a formatted output.
222
  """
223
 
 
 
 
 
 
224
  logger.info(f"User Question:\n{question}")
225
 
226
  matched_documents = self.rank_documents(
 
10
  from openai.embeddings_utils import cosine_similarity, get_embedding
11
 
12
  from buster.docparser import read_documents
13
+ from buster.formatter import Formatter, SlackFormatter, HTMLFormatter, MarkdownFormatter
14
  from buster.formatter.base import Response, Source
15
 
16
 
17
+ FORMATTERS = {"text": Formatter, "slack": SlackFormatter, "html": HTMLFormatter, "markdown": MarkdownFormatter}
18
+
19
+
20
  logger = logging.getLogger(__name__)
21
  logging.basicConfig(level=logging.INFO)
22
 
 
219
  # Likely that the answer is meaningful, add the top sources
220
  return score < unk_threshold
221
 
222
+ def process_input(self, question: str, formatter: Formatter = None) -> str:
223
  """
224
  Main function to process the input question and generate a formatted output.
225
  """
226
 
227
+ if formatter is None and self.cfg.link_format not in FORMATTERS:
228
+ raise ValueError(f"Unknown link format {self.cfg.link_format}")
229
+ elif formatter is None:
230
+ formatter = FORMATTERS[self.cfg.link_format]()
231
+
232
  logger.info(f"User Question:\n{question}")
233
 
234
  matched_documents = self.rank_documents(