File size: 9,530 Bytes
e112463
2118bb2
fa9ac7e
0e4a27a
e112463
 
 
 
2118bb2
e112463
 
71e7dd8
5b7d0e6
 
 
 
 
 
cb9fd37
e112463
 
 
2118bb2
 
 
 
 
 
 
 
 
 
05dabf4
1ea867d
 
 
 
 
 
 
 
6a09028
1ea867d
 
5b7d0e6
1ea867d
 
5b7d0e6
1ea867d
 
71e7dd8
1ea867d
 
 
6a09028
 
1ea867d
 
 
 
 
 
 
 
 
 
 
 
5b7d0e6
1ea867d
8b5fed9
 
5b7d0e6
1ea867d
 
fa9ac7e
1ea867d
fa9ac7e
 
 
 
5b7d0e6
 
 
 
 
 
e112463
fa9ac7e
c5f5dc3
 
71e7dd8
c5f5dc3
e112463
fa9ac7e
6a09028
fa9ac7e
6a09028
 
fa9ac7e
e112463
fa9ac7e
 
 
6a09028
 
 
fa9ac7e
 
 
 
f97aa81
fa9ac7e
 
 
 
71e7dd8
f97aa81
fa9ac7e
 
e112463
fa9ac7e
 
 
 
 
 
e112463
6a09028
 
1f22b14
fa9ac7e
6a09028
 
 
 
 
fa9ac7e
6a09028
 
e112463
6a09028
e112463
6a09028
 
 
 
 
 
 
fa9ac7e
6a09028
fa9ac7e
6a09028
 
fa9ac7e
0e4a27a
fa9ac7e
6a09028
fa9ac7e
0e4a27a
fa9ac7e
 
6a09028
0e4a27a
 
 
 
fa9ac7e
0e4a27a
 
 
6a09028
 
 
 
 
f9dc69d
 
6a09028
 
 
0e4a27a
 
 
 
 
 
 
 
 
 
5b7d0e6
0e4a27a
 
fa9ac7e
5b7d0e6
 
0e4a27a
e112463
0e4a27a
fa9ac7e
6a09028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7d0e6
fa9ac7e
 
 
 
 
 
1a12df2
 
 
 
6a09028
 
 
 
 
 
 
 
 
 
 
 
0e4a27a
fa9ac7e
5b7d0e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import logging
import os
from dataclasses import dataclass, field
from typing import Iterable

import numpy as np
import openai
import pandas as pd
import promptlayer
from openai.embeddings_utils import cosine_similarity, get_embedding

from buster.documents import get_documents_manager_from_extension
from buster.formatter import (
    Response,
    ResponseFormatter,
    Source,
    response_formatter_factory,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# PromptLayer is opt-in: it is only enabled when its API key is present in the environment.
if promptlayer_api_key := os.environ.get("PROMPTLAYER_API_KEY"):
    logger.info("Enabling prompt layer...")
    promptlayer.api_key = promptlayer_api_key

    # Rebind the module-level `openai` name to PromptLayer's wrapper, so later calls made
    # through `openai` in this module go through PromptLayer; the wrapper needs its own key.
    openai = promptlayer.openai
    openai.api_key = os.environ.get("OPENAI_API_KEY")


@dataclass
class ChatbotConfig:
    """Configuration object for a chatbot.

    documents_file: Path to the file containing the documents and their embeddings. The
        documents manager used to read it is selected from the file's extension.
    embedding_model: OpenAI model to use to get embeddings.
    top_k: Max number of documents to retrieve, ordered by cosine similarity.
    thresh: Cosine-similarity threshold; only documents scoring strictly above it are kept.
        A falsy value (0) disables this filter.
    max_words: Maximum number of words the retrieved documents can be. Will truncate otherwise.
    unknown_threshold: Cosine-similarity threshold between a generated response and the
        embedding of `unknown_prompt`; a response is kept only when its score is below this
        value, otherwise it is replaced by `unknown_prompt`.
    completion_kwargs: kwargs for the OpenAI.Completion() method.
    separator: The separator to use, can be either "\n" or <p> depending on rendering.
    response_format: The type of format to render links with, e.g. slack or markdown.
    unknown_prompt: Prompt to use to generate the "I don't know" embedding to compare to.
        It is also returned as the answer when no documents match or the response is rejected.
    text_before_documents: Text to prompt GPT with before the documentation.
    text_before_prompt: Text to prompt GPT with before the user prompt, but after the documentation.
    response_footnote: Generic response to add to the chatbot's reply.
    """

    documents_file: str = "buster/data/document_embeddings.tar.gz"
    embedding_model: str = "text-embedding-ada-002"
    top_k: int = 3
    thresh: float = 0.7
    max_words: int = 3000
    unknown_threshold: float = 0.9  # set to 0 to deactivate

    # default_factory gives every config instance its own dict (no shared mutable default).
    completion_kwargs: dict = field(
        default_factory=lambda: {
            "engine": "text-davinci-003",
            "max_tokens": 200,
            "temperature": None,
            "top_p": None,
            "frequency_penalty": 1,
            "presence_penalty": 1,
        }
    )
    separator: str = "\n"
    response_format: str = "slack"
    unknown_prompt: str = "I Don't know how to answer your question."
    text_before_documents: str = "You are a chatbot answering questions.\n"
    text_before_prompt: str = "Answer the following question:\n"
    response_footnote: str = "I'm a bot 🤖 and not always perfect."


class Chatbot:
    """Answer questions using retrieved documentation and an OpenAI completion model.

    Pipeline: embed the question, retrieve the closest documents, build a prompt from them,
    query the completion model, reject answers that look like "I don't know", and render the
    result (with its sources) through a response formatter.
    """

    def __init__(self, cfg: ChatbotConfig):
        # TODO: right now, the cfg is being passed as an omegaconf, is this what we want?
        self.cfg = cfg
        self._init_documents()
        self._init_unk_embedding()
        self._init_response_formatter()

    def _init_response_formatter(self):
        """Build the formatter selected by ``cfg.response_format`` (e.g. slack or markdown)."""
        self.response_formatter = response_formatter_factory(
            format=self.cfg.response_format, response_footnote=self.cfg.response_footnote
        )

    def _init_documents(self):
        """Load the documents and their embeddings from ``cfg.documents_file``.

        The documents manager class is chosen from the file's extension.
        """
        filepath = self.cfg.documents_file
        logger.info(f"loading embeddings from {filepath}...")
        self.documents = get_documents_manager_from_extension(filepath)(filepath)
        logger.info("embeddings loaded.")

    def _init_unk_embedding(self):
        """Embed ``cfg.unknown_prompt`` once, for later comparison against generated answers.

        This calls the OpenAI embeddings endpoint at construction time.
        """
        logger.info("Generating UNK embedding...")
        self.unk_embedding = get_embedding(
            self.cfg.unknown_prompt,
            engine=self.cfg.embedding_model,
        )

    def rank_documents(
        self,
        query: str,
        top_k: int,
        thresh: float,
        engine: str,
    ) -> pd.DataFrame:
        """Compare the question to the series of documents and return the best matching documents.

        Args:
            query: Text to embed and search with.
            top_k: Maximum number of documents to retrieve.
            thresh: Keep only documents whose ``similarity`` is strictly greater than this.
                A falsy value (0 / None) disables the filter.
            engine: OpenAI embedding model used to embed ``query``.

        Returns:
            The retrieved documents (with a ``similarity`` column), after thresholding.
        """
        query_embedding = get_embedding(
            query,
            engine=engine,
        )
        matched_documents = self.documents.retrieve(query_embedding, top_k)

        # log matched_documents to the console
        logger.info(f"matched documents before thresh: {matched_documents}")

        # filter out matched_documents using a threshold
        if thresh:
            matched_documents = matched_documents[matched_documents.similarity > thresh]
            logger.info(f"matched documents after thresh: {matched_documents}")

        return matched_documents

    def prepare_documents(self, matched_documents: pd.DataFrame, max_words: int) -> str:
        """Join the ``content`` of the matched documents into one string of at most ``max_words`` words.

        Words are counted by splitting on single spaces only, so newlines do not separate words here.
        """
        # gather the documents in one large plaintext variable
        documents_list = matched_documents.content.to_list()
        documents_str = " ".join(documents_list)

        # truncate the documents to fit
        # TODO: increase to actual token count
        word_count = len(documents_str.split(" "))
        if word_count > max_words:
            logger.info("truncating documents to fit...")
            documents_str = " ".join(documents_str.split(" ")[0:max_words])
            logger.info(f"Documents after truncation: {documents_str}")

        return documents_str

    def prepare_prompt(
        self,
        question: str,
        matched_documents: pd.DataFrame,
        text_before_prompt: str,
        text_before_documents: str,
    ) -> str:
        """Prepare the prompt with prompt engineering.

        The result is ``text_before_documents + documents + text_before_prompt + question``,
        concatenated without any extra separator; documents are truncated to ``cfg.max_words``.
        """
        documents_str: str = self.prepare_documents(matched_documents, max_words=self.cfg.max_words)
        return text_before_documents + documents_str + text_before_prompt + question

    def get_gpt_response(self, **completion_kwargs) -> Response:
        """Call ``openai.Completion.create`` and wrap the first choice's text in a ``Response``.

        Any exception from the API call is logged and turned into an error ``Response`` with a
        generic message, so callers never see the exception.
        """
        logger.info("querying GPT...")
        try:
            response = openai.Completion.create(**completion_kwargs)
        except Exception:
            # log the error and return a generic response instead.
            logger.exception("Error connecting to OpenAI API. See traceback:")
            return Response("", True, "We're having trouble connecting to OpenAI right now... Try again soon!")

        text = response["choices"][0]["text"]
        return Response(text)

    def generate_response(
        self, prompt: str, matched_documents: pd.DataFrame, unknown_prompt: str
    ) -> tuple[Response, Iterable[Source]]:
        """Generate a response based on the retrieved documents.

        Args:
            prompt: Fully prepared completion prompt.
            matched_documents: Documents retrieved for the question; each row needs
                ``source``, ``url`` and ``similarity``.
            unknown_prompt: Answer to return when there are no documents or the generated
                answer is judged irrelevant.

        Returns:
            ``(response, sources)``. ``sources`` is empty unless a relevant answer was produced.
        """
        no_sources: tuple = tuple()

        if len(matched_documents) == 0:
            # No matching documents were retrieved, return
            return Response(unknown_prompt), no_sources

        logger.info(f"Prompt:  {prompt}")
        response = self.get_gpt_response(prompt=prompt, **self.cfg.completion_kwargs)
        if not response:
            # A falsy Response (presumably a failed completion -- confirm against
            # buster.formatter.Response) is returned as-is with no sources. The old code left
            # `sources` unassigned on this path and raised UnboundLocalError.
            return response, no_sources

        logger.info(f"GPT Response:\n{response.text}")
        relevant = self.check_response_relevance(
            response=response.text,
            engine=self.cfg.embedding_model,
            unk_embedding=self.unk_embedding,
            unk_threshold=self.cfg.unknown_threshold,
        )
        if not relevant:
            # Override the answer with the generic unknown prompt, without sources. Use the
            # parameter (not cfg) so this agrees with the no-documents branch above.
            return Response(text=unknown_prompt), no_sources

        sources = (
            Source(dct["source"], dct["url"], dct["similarity"])
            for dct in matched_documents.to_dict(orient="records")
        )
        return response, sources

    def check_response_relevance(
        self, response: str, engine: str, unk_embedding: np.array, unk_threshold: float
    ) -> bool:
        """Check to see if a response is relevant to the chatbot's knowledge or not.

        We assume we've prompt-engineered our bot to say a response is unrelated to the context
        if it isn't relevant. Here, we compare the embedding of the response to the embedding of
        the prompt-engineered "I don't know" answer: the more similar they are, the more likely
        the bot declined to answer.

        Args:
            response: Generated response text to check.
            engine: OpenAI embedding model used to embed ``response``.
            unk_embedding: Embedding of the "I don't know" prompt.
            unk_threshold: The response is relevant when its cosine similarity to
                ``unk_embedding`` is strictly below this value. Set it to 0 (any falsy value)
                to turn the check off; every response is then considered relevant.

        Returns:
            True if the response should be kept, False if it should be replaced.
        """
        if not unk_threshold:
            # With a plain `score < 0` test, every response with non-negative similarity would
            # be rejected -- the opposite of the documented "0 turns this off". Short-circuit
            # instead (and skip the embedding call).
            return True

        response_embedding = get_embedding(
            response,
            engine=engine,
        )
        score = cosine_similarity(response_embedding, unk_embedding)
        logger.info(f"UNK score: {score}")

        # Likely that the answer is meaningful, add the top sources
        return score < unk_threshold

    def process_input(self, question: str, formatter: ResponseFormatter = None) -> str:
        """Main function to process the input question and generate a formatted output.

        Args:
            question: The user's question.
            formatter: Optional formatter to render the answer with; defaults to the one built
                from ``cfg.response_format``.

        Returns:
            Whatever the formatter produces for ``(response, sources)``.
        """
        logger.info(f"User Question:\n{question}")

        # We make sure there is always a newline at the end of the question to avoid completing the question.
        if not question.endswith("\n"):
            question += "\n"

        matched_documents = self.rank_documents(
            query=question,
            top_k=self.cfg.top_k,
            thresh=self.cfg.thresh,
            engine=self.cfg.embedding_model,
        )
        prompt = self.prepare_prompt(
            question=question,
            matched_documents=matched_documents,
            text_before_prompt=self.cfg.text_before_prompt,
            text_before_documents=self.cfg.text_before_documents,
        )
        response, sources = self.generate_response(prompt, matched_documents, self.cfg.unknown_prompt)

        # Honour an explicitly passed formatter; previously the argument was silently ignored.
        if formatter is None:
            formatter = self.response_formatter
        return formatter(response, sources)