File size: 8,459 Bytes
e112463
fa9ac7e
c6dd20e
e112463
 
 
 
 
c8a1687
5b7d0e6
 
 
 
 
 
cb9fd37
e112463
 
 
05dabf4
1ea867d
c8a1687
1ea867d
 
 
 
 
 
6a09028
1ea867d
 
5b7d0e6
1ea867d
 
5b7d0e6
c6dd20e
1ea867d
 
71e7dd8
1ea867d
 
 
6a09028
 
c8a1687
 
1ea867d
c8a1687
 
 
 
 
 
 
 
 
 
 
1ea867d
 
5b7d0e6
1ea867d
5b7d0e6
c6dd20e
 
 
06bca0c
1ea867d
 
c8a1687
06bca0c
c6dd20e
fa9ac7e
c6dd20e
 
06bca0c
c6dd20e
 
 
 
 
 
 
 
 
 
5b7d0e6
c6dd20e
 
 
 
 
 
5b7d0e6
 
 
c6dd20e
e112463
c6dd20e
 
 
 
e112463
fa9ac7e
 
 
6a09028
 
 
c6dd20e
fa9ac7e
 
 
 
f97aa81
c6dd20e
fa9ac7e
 
 
06bca0c
f97aa81
fa9ac7e
 
e112463
fa9ac7e
 
 
 
 
 
e112463
6a09028
 
1f22b14
c6dd20e
 
 
6a09028
 
 
 
 
fa9ac7e
6a09028
 
e112463
6a09028
e112463
c8a1687
6a09028
 
c8a1687
 
c6dd20e
 
 
 
c8a1687
 
 
fa9ac7e
6a09028
c8a1687
6a09028
 
 
 
 
 
 
 
c6dd20e
c8a1687
6a09028
 
 
 
 
 
 
 
c8a1687
fa9ac7e
 
 
 
c8a1687
fa9ac7e
1a12df2
c8a1687
 
1a12df2
6a09028
c8a1687
6a09028
 
 
c6dd20e
6a09028
c8a1687
 
c6dd20e
c8a1687
 
 
 
 
c6dd20e
463196d
 
 
c8a1687
 
 
 
 
 
 
6a09028
c8a1687
 
 
 
 
fa9ac7e
5b7d0e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import logging
from dataclasses import dataclass, field
from functools import lru_cache

import numpy as np
import pandas as pd
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.completers import get_completer
from buster.formatter import (
    Response,
    ResponseFormatter,
    Source,
    response_formatter_factory,
)

# Module-level logger for this file.
# NOTE(review): basicConfig configures the *root* logger at import time (it is a no-op if the root logger
# already has handlers), which affects every program importing this module.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level="INFO")


@dataclass
class BusterConfig:
    """Configuration object for a chatbot.

    documents_file: Path to the file containing the documents and their embeddings.
    embedding_model: OpenAI model used to compute embeddings.
    top_k: Max number of documents to retrieve, ordered by cosine similarity.
    thresh: Cosine-similarity threshold; only documents scoring strictly above it are kept.
        A falsy value (e.g. 0) disables the filter.
    max_words: Maximum number of words the retrieved documents can total. Will truncate otherwise.
    unknown_threshold: A generated answer whose embedding has a cosine similarity to the
        ``unknown_prompt`` embedding at or above this value is replaced by ``unknown_prompt``.
    completer_cfg: Configuration for the completer: its ``name``, the ``text_before_documents``
        and ``text_before_prompt`` prompt fragments, and ``completion_kwargs`` (keyword arguments
        for the OpenAI completion call).
    response_format: The format used to render replies and links, e.g. slack or markdown.
    unknown_prompt: Text used to build the "I don't know" embedding; it is also the reply sent
        when no suitable answer is found.
    response_footnote: Generic text added to the chatbot's reply.
    source: The source of the documents to consider.
    """

    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 0 to deactivate
    completer_cfg: dict = field(
        # TODO: Put all this in its own config with sane defaults?
        default_factory=lambda: {
            "name": "GPT3",
            "text_before_documents": "You are a chatbot answering questions.\n",
            "text_before_prompt": "Answer the following question:\n",
            "completion_kwargs": {
                "engine": "text-davinci-003",
                "max_tokens": 200,
                "temperature": None,
                "top_p": None,
                "frequency_penalty": 1,
                "presence_penalty": 1,
            },
        }
    )
    response_format: str = "slack"
    unknown_prompt: str = "I Don't know how to answer your question."
    response_footnote: str = "I'm a bot 🤖 and not always perfect."
    source: str = ""


from buster.retriever import Retriever


class Buster:
    """Retrieval-augmented chatbot.

    A user question is embedded, the closest documents are fetched through ``retriever``, and the
    configured completer answers from those documents. Answers whose embedding is too similar to the
    embedding of ``cfg.unknown_prompt`` are replaced by that prompt, and the result is rendered by a
    response formatter.
    """

    def __init__(self, cfg: BusterConfig, retriever: Retriever):
        """Build the bot from ``cfg``; ``retriever`` is used to look up matching documents."""
        self._unk_embedding = None
        self.cfg = cfg
        # update_cfg builds the completer, the UNK embedding and the response formatter.
        self.update_cfg(cfg)

        self.retriever = retriever

    @property
    def unk_embedding(self):
        """Embedding of ``cfg.unknown_prompt``, compared against every generated answer."""
        return self._unk_embedding

    @unk_embedding.setter
    def unk_embedding(self, embedding):
        logger.info("Setting new UNK embedding...")
        self._unk_embedding = embedding

    def update_cfg(self, cfg: BusterConfig):
        """Install a new config and rebuild everything derived from it.

        Rebuilds the completer, the "I don't know" embedding and the response formatter.
        """
        logger.info(f"Updating config to {cfg.source}:\n{cfg}")
        self.cfg = cfg
        self.completer = get_completer(cfg.completer_cfg)
        self.unk_embedding = self.get_embedding(self.cfg.unknown_prompt, engine=self.cfg.embedding_model)
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )
        logger.info("Config Updated.")

    @staticmethod
    @lru_cache(maxsize=128)
    def get_embedding(query: str, engine: str):
        """Return the embedding of ``query`` computed with the model ``engine``.

        Results are memoized on ``(query, engine)`` in one cache shared by all instances. The cache
        lives on a static method rather than a bound method so it does not hold a reference to
        every ``Buster`` instance in its keys. Repeated calls return the same object, so callers
        must not mutate it.
        """
        logger.info("generating embedding")
        # Resolves to the module-level ``get_embedding`` imported from openai, not to this method.
        return get_embedding(query, engine=engine)

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
        source: str,
    ) -> pd.DataFrame:
        """Compare the question to the documents and return the best matches.

        Args:
            query: Text to embed and compare against the documents.
            top_k: Maximum number of documents requested from the retriever.
            thresh: Keep only rows whose ``similarity`` is strictly greater than this value.
                A falsy value (0 or None) disables the filter.
            engine: Embedding model used for the query.
            source: Forwarded as-is to ``retriever.retrieve``.

        Returns:
            The DataFrame returned by the retriever, filtered by ``thresh`` when it is set.
        """
        query_embedding = self.get_embedding(query, engine=engine)
        matched_documents = self.retriever.retrieve(query_embedding, top_k=top_k, source=source)

        logger.info(f"matched documents before thresh: {matched_documents}")

        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
        """Concatenate the ``content`` column of the matched documents into one prompt string.

        Each document is wrapped as ``<DOCUMENT> {doc} <\\DOCUMENT>``. The result is cut to its
        first ``max_words`` space-separated words when it is longer than that.
        """
        # "<\\DOCUMENT>" is the same runtime text the original wrote as "<\DOCUMENT>" ("\D" is an
        # invalid escape that Python kept literally, with a warning).
        # NOTE(review): a conventional closing tag would be "</DOCUMENT>"; changing it alters the
        # prompt, so it is left as-is.
        documents_str = "".join(f"<DOCUMENT> {doc} <\\DOCUMENT>" for doc in matched_documents.content.to_list())

        # TODO: truncate on the actual token count rather than on words.
        words = documents_str.split(" ")
        if len(words) > max_words:
            logger.info("truncating documents to fit...")
            documents_str = " ".join(words[:max_words])
            logger.info(f"Documents after truncation: {documents_str}")

        return documents_str

    def add_sources(
        self,
        matched_documents: pd.DataFrame,
    ):
        """Build one ``Source`` per matched document, with similarity expressed as a percentage.

        Returns a tuple rather than a one-shot generator, so the result can be iterated more than
        once. This also matches the empty ``tuple()`` used when no answer is given.
        """
        return tuple(
            Source(
                source=dct["source"],
                title=dct["title"],
                url=dct["url"],
                question_similarity=dct["similarity"] * 100,
            )
            for dct in matched_documents.to_dict(orient="records")
        )

    def check_response_relevance(
        self, completion: str, engine: str, unk_embedding: np.ndarray, unk_threshold: float
    ) -> bool:
        """Check whether a response is relevant to the chatbot's knowledge.

        We assume the bot has been prompt-engineered to say it doesn't know when the context is
        unrelated. The embedding of the response is compared with the embedding of that
        "I don't know" prompt; the response is relevant when their cosine similarity is strictly
        below ``unk_threshold``.

        Set ``unk_threshold`` to 0 (or None) to turn this check off: every response is then relevant.
        """
        if not unk_threshold:
            # Without this guard the comparison below would be ``score < 0``. That rejects every
            # answer with positive similarity instead of disabling the check.
            logger.info("UNK check disabled (unk_threshold is 0); treating response as relevant.")
            return True

        response_embedding = self.get_embedding(completion, engine=engine)
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Far enough from the "I don't know" embedding: the answer is likely meaningful.
        return score < unk_threshold

    def process_input(self, user_input: str, formatter: ResponseFormatter = None) -> str:
        """Answer ``user_input`` and return the formatted response.

        Args:
            user_input: The user's question.
            formatter: Formatter used to render this response. When None, the formatter built
                from ``cfg.response_format`` by ``update_cfg`` is used.

        Returns:
            Whatever the formatter returns for ``(response, sources)``.
        """
        if formatter is None:
            formatter = self.response_formatter

        logger.info(f"User Input:\n{user_input}")

        # Always end the question with a newline so the model answers it instead of completing it.
        if not user_input.endswith("\n"):
            user_input += "\n"

        matched_documents = self.rank_documents(
            query=user_input,
            top_k=self.cfg.top_k,
            thresh=self.cfg.thresh,
            engine=self.cfg.embedding_model,
            source=self.cfg.source,
        )

        if len(matched_documents) == 0:
            response = Response(self.cfg.unknown_prompt)
            return formatter(response, tuple())

        documents = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
        response = self.completer.generate_response(user_input, documents)
        logger.info(f"GPT Response:\n{response.text}")

        relevant = self.check_response_relevance(
            completion=response.text,
            engine=self.cfg.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.cfg.unknown_threshold,
        )
        if relevant:
            sources = self.add_sources(matched_documents)
        else:
            # The bot said it doesn't know: replace the completion with the generic "I don't know"
            # reply and show no sources.
            response = Response(text=self.cfg.unknown_prompt)
            sources = tuple()

        return formatter(response, sources)